author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-07-27 07:55:01 +0200
committer GitHub <noreply@github.com>                             2024-07-27 07:55:01 +0200
commit    154e0d75fccf1784fe9ff6fd76a630b66563da3d (patch)
tree      81ce6dbb5b1900c1aa78a879f0593c694cab9d27 /ggml
parent    0684c3e9c70d49323b4fc517128cbe222cab7f96 (diff)
Merge mainline llama.cpp (#3)
* Merging mainline - WIP

* Merging mainline - WIP

  AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower,
  as is so often the case with llama.cpp/ggml after some "improvements" have
  been made.

* Merging mainline - fix Metal

* Remove check

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml')
-rw-r--r--ggml/.gitignore2
-rw-r--r--ggml/CMakeLists.txt243
-rw-r--r--ggml/cmake/FindSIMD.cmake100
-rw-r--r--ggml/include/ggml-alloc.h76
-rw-r--r--ggml/include/ggml-backend.h238
-rw-r--r--ggml/include/ggml-blas.h23
-rw-r--r--ggml/include/ggml-cann.h125
-rw-r--r--ggml/include/ggml-cuda.h44
-rw-r--r--ggml/include/ggml-kompute.h46
-rw-r--r--ggml/include/ggml-metal.h65
-rw-r--r--ggml/include/ggml-rpc.h24
-rw-r--r--ggml/include/ggml-sycl.h42
-rw-r--r--ggml/include/ggml-vulkan.h29
-rw-r--r--ggml/include/ggml.h2453
-rw-r--r--ggml/src/CMakeLists.txt1291
-rw-r--r--ggml/src/ggml-aarch64.c2193
-rw-r--r--ggml/src/ggml-aarch64.h39
-rw-r--r--ggml/src/ggml-alloc.c1042
-rw-r--r--ggml/src/ggml-backend-impl.h153
-rw-r--r--ggml/src/ggml-backend.c2234
-rw-r--r--ggml/src/ggml-blas.cpp368
-rw-r--r--ggml/src/ggml-cann.cpp2023
-rw-r--r--ggml/src/ggml-cann/.clang-format168
-rw-r--r--ggml/src/ggml-cann/Doxyfile2579
-rw-r--r--ggml/src/ggml-cann/acl_tensor.cpp198
-rw-r--r--ggml/src/ggml-cann/acl_tensor.h230
-rw-r--r--ggml/src/ggml-cann/aclnn_ops.cpp2944
-rw-r--r--ggml/src/ggml-cann/aclnn_ops.h592
-rw-r--r--ggml/src/ggml-cann/common.h282
-rw-r--r--ggml/src/ggml-cann/kernels/CMakeLists.txt32
-rw-r--r--ggml/src/ggml-cann/kernels/ascendc_kernels.h17
-rw-r--r--ggml/src/ggml-cann/kernels/dup.cpp223
-rw-r--r--ggml/src/ggml-cann/kernels/get_row_f16.cpp186
-rw-r--r--ggml/src/ggml-cann/kernels/get_row_f32.cpp180
-rw-r--r--ggml/src/ggml-cann/kernels/get_row_q4_0.cpp193
-rw-r--r--ggml/src/ggml-cann/kernels/get_row_q8_0.cpp191
-rw-r--r--ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp208
-rw-r--r--ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp206
-rw-r--r--ggml/src/ggml-common.h1880
-rw-r--r--ggml/src/ggml-cuda.cu3079
-rw-r--r--ggml/src/ggml-cuda/acc.cu47
-rw-r--r--ggml/src/ggml-cuda/acc.cuh5
-rw-r--r--ggml/src/ggml-cuda/arange.cu34
-rw-r--r--ggml/src/ggml-cuda/arange.cuh5
-rw-r--r--ggml/src/ggml-cuda/argsort.cu104
-rw-r--r--ggml/src/ggml-cuda/argsort.cuh3
-rw-r--r--ggml/src/ggml-cuda/binbcast.cu316
-rw-r--r--ggml/src/ggml-cuda/binbcast.cuh6
-rw-r--r--ggml/src/ggml-cuda/clamp.cu34
-rw-r--r--ggml/src/ggml-cuda/clamp.cuh5
-rw-r--r--ggml/src/ggml-cuda/common.cuh871
-rw-r--r--ggml/src/ggml-cuda/concat.cu196
-rw-r--r--ggml/src/ggml-cuda/concat.cuh5
-rw-r--r--ggml/src/ggml-cuda/conv-transpose-1d.cu87
-rw-r--r--ggml/src/ggml-cuda/conv-transpose-1d.cuh5
-rw-r--r--ggml/src/ggml-cuda/convert.cu775
-rw-r--r--ggml/src/ggml-cuda/convert.cuh13
-rw-r--r--ggml/src/ggml-cuda/cpy.cu489
-rw-r--r--ggml/src/ggml-cuda/cpy.cuh9
-rw-r--r--ggml/src/ggml-cuda/dequantize.cuh103
-rw-r--r--ggml/src/ggml-cuda/diagmask.cu40
-rw-r--r--ggml/src/ggml-cuda/diagmask.cuh5
-rw-r--r--ggml/src/ggml-cuda/dmmv.cu674
-rw-r--r--ggml/src/ggml-cuda/dmmv.cuh18
-rw-r--r--ggml/src/ggml-cuda/fattn-common.cuh701
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f16.cu319
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f16.cuh3
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f32.cu312
-rw-r--r--ggml/src/ggml-cuda/fattn-tile-f32.cuh3
-rw-r--r--ggml/src/ggml-cuda/fattn-vec-f16.cuh397
-rw-r--r--ggml/src/ggml-cuda/fattn-vec-f32.cuh374
-rw-r--r--ggml/src/ggml-cuda/fattn-wmma-f16.cuh490
-rw-r--r--ggml/src/ggml-cuda/fattn.cu345
-rw-r--r--ggml/src/ggml-cuda/fattn.cuh3
-rw-r--r--ggml/src/ggml-cuda/getrows.cu178
-rw-r--r--ggml/src/ggml-cuda/getrows.cuh5
-rw-r--r--ggml/src/ggml-cuda/im2col.cu104
-rw-r--r--ggml/src/ggml-cuda/im2col.cuh5
-rw-r--r--ggml/src/ggml-cuda/mma.cuh221
-rw-r--r--ggml/src/ggml-cuda/mmq.cu150
-rw-r--r--ggml/src/ggml-cuda/mmq.cuh2936
-rw-r--r--ggml/src/ggml-cuda/mmvq.cu447
-rw-r--r--ggml/src/ggml-cuda/mmvq.cuh9
-rw-r--r--ggml/src/ggml-cuda/norm.cu221
-rw-r--r--ggml/src/ggml-cuda/norm.cuh7
-rw-r--r--ggml/src/ggml-cuda/pad.cu49
-rw-r--r--ggml/src/ggml-cuda/pad.cuh5
-rw-r--r--ggml/src/ggml-cuda/pool2d.cu94
-rw-r--r--ggml/src/ggml-cuda/pool2d.cuh5
-rw-r--r--ggml/src/ggml-cuda/quantize.cu169
-rw-r--r--ggml/src/ggml-cuda/quantize.cuh24
-rw-r--r--ggml/src/ggml-cuda/rope.cu271
-rw-r--r--ggml/src/ggml-cuda/rope.cuh5
-rw-r--r--ggml/src/ggml-cuda/scale.cu31
-rw-r--r--ggml/src/ggml-cuda/scale.cuh5
-rw-r--r--ggml/src/ggml-cuda/softmax.cu206
-rw-r--r--ggml/src/ggml-cuda/softmax.cuh5
-rw-r--r--ggml/src/ggml-cuda/sumrows.cu40
-rw-r--r--ggml/src/ggml-cuda/sumrows.cuh3
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu10
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu9
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu10
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu10
-rw-r--r--ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu8
-rwxr-xr-xggml/src/ggml-cuda/template-instances/generate_cu_files.py77
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu5
-rw-r--r--ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu5
-rw-r--r--ggml/src/ggml-cuda/tsembd.cu47
-rw-r--r--ggml/src/ggml-cuda/tsembd.cuh5
-rw-r--r--ggml/src/ggml-cuda/unary.cu314
-rw-r--r--ggml/src/ggml-cuda/unary.cuh33
-rw-r--r--ggml/src/ggml-cuda/upscale.cu51
-rw-r--r--ggml/src/ggml-cuda/upscale.cuh5
-rw-r--r--ggml/src/ggml-cuda/vecdotq.cuh1229
-rw-r--r--ggml/src/ggml-impl.h655
-rw-r--r--ggml/src/ggml-kompute.cpp2038
-rw-r--r--ggml/src/ggml-metal.m3380
-rw-r--r--ggml/src/ggml-metal.metal6563
-rw-r--r--ggml/src/ggml-quants.c14976
-rw-r--r--ggml/src/ggml-quants.h151
-rw-r--r--ggml/src/ggml-rpc.cpp1178
-rw-r--r--ggml/src/ggml-sycl.cpp5314
-rw-r--r--ggml/src/ggml-sycl/backend.hpp27
-rw-r--r--ggml/src/ggml-sycl/common.cpp53
-rw-r--r--ggml/src/ggml-sycl/common.hpp355
-rw-r--r--ggml/src/ggml-sycl/concat.cpp195
-rw-r--r--ggml/src/ggml-sycl/concat.hpp21
-rw-r--r--ggml/src/ggml-sycl/convert.cpp547
-rw-r--r--ggml/src/ggml-sycl/convert.hpp27
-rw-r--r--ggml/src/ggml-sycl/dequantize.hpp698
-rw-r--r--ggml/src/ggml-sycl/dmmv.cpp1023
-rw-r--r--ggml/src/ggml-sycl/dmmv.hpp27
-rw-r--r--ggml/src/ggml-sycl/dpct/helper.hpp3011
-rw-r--r--ggml/src/ggml-sycl/mmq.cpp3031
-rw-r--r--ggml/src/ggml-sycl/mmq.hpp33
-rw-r--r--ggml/src/ggml-sycl/mmvq.cpp1027
-rw-r--r--ggml/src/ggml-sycl/mmvq.hpp27
-rw-r--r--ggml/src/ggml-sycl/norm.cpp374
-rw-r--r--ggml/src/ggml-sycl/norm.hpp35
-rw-r--r--ggml/src/ggml-sycl/presets.hpp66
-rw-r--r--ggml/src/ggml-sycl/rope.cpp275
-rw-r--r--ggml/src/ggml-sycl/rope.hpp22
-rw-r--r--ggml/src/ggml-sycl/softmax.cpp251
-rw-r--r--ggml/src/ggml-sycl/softmax.hpp24
-rw-r--r--ggml/src/ggml-sycl/vecdotq.hpp1140
-rw-r--r--ggml/src/ggml-vulkan.cpp7022
-rw-r--r--ggml/src/ggml.c22196
-rw-r--r--ggml/src/iqk/iqk_mul_mat.cpp4757
-rw-r--r--ggml/src/iqk/iqk_mul_mat.h27
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp414
m---------ggml/src/kompute0
-rw-r--r--ggml/src/kompute-shaders/common.comp102
-rw-r--r--ggml/src/kompute-shaders/op_add.comp58
-rw-r--r--ggml/src/kompute-shaders/op_addrow.comp25
-rw-r--r--ggml/src/kompute-shaders/op_cpy_f16_f16.comp52
-rw-r--r--ggml/src/kompute-shaders/op_cpy_f16_f32.comp52
-rw-r--r--ggml/src/kompute-shaders/op_cpy_f32_f16.comp52
-rw-r--r--ggml/src/kompute-shaders/op_cpy_f32_f32.comp52
-rw-r--r--ggml/src/kompute-shaders/op_diagmask.comp30
-rw-r--r--ggml/src/kompute-shaders/op_gelu.comp22
-rw-r--r--ggml/src/kompute-shaders/op_getrows.comp17
-rw-r--r--ggml/src/kompute-shaders/op_getrows_f16.comp31
-rw-r--r--ggml/src/kompute-shaders/op_getrows_f32.comp31
-rw-r--r--ggml/src/kompute-shaders/op_getrows_q4_0.comp38
-rw-r--r--ggml/src/kompute-shaders/op_getrows_q4_1.comp39
-rw-r--r--ggml/src/kompute-shaders/op_getrows_q6_k.comp44
-rw-r--r--ggml/src/kompute-shaders/op_mul.comp52
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_f16.comp67
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp51
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_q4_0.comp33
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_q4_1.comp35
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_q6_k.comp94
-rw-r--r--ggml/src/kompute-shaders/op_mul_mat_q8_0.comp73
-rw-r--r--ggml/src/kompute-shaders/op_mul_mv_q_n.comp48
-rw-r--r--ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp22
-rw-r--r--ggml/src/kompute-shaders/op_norm.comp84
-rw-r--r--ggml/src/kompute-shaders/op_relu.comp21
-rw-r--r--ggml/src/kompute-shaders/op_rmsnorm.comp53
-rw-r--r--ggml/src/kompute-shaders/op_rope_f16.comp73
-rw-r--r--ggml/src/kompute-shaders/op_rope_f32.comp73
-rw-r--r--ggml/src/kompute-shaders/op_scale.comp19
-rw-r--r--ggml/src/kompute-shaders/op_scale_8.comp23
-rw-r--r--ggml/src/kompute-shaders/op_silu.comp22
-rw-r--r--ggml/src/kompute-shaders/op_softmax.comp56
-rw-r--r--ggml/src/kompute-shaders/rope_common.comp67
-rw-r--r--ggml/src/llamafile/sgemm.cpp1028
-rw-r--r--ggml/src/llamafile/sgemm.h14
-rw-r--r--ggml/src/vulkan-shaders/CMakeLists.txt5
-rw-r--r--ggml/src/vulkan-shaders/add.comp12
-rw-r--r--ggml/src/vulkan-shaders/argsort.comp71
-rw-r--r--ggml/src/vulkan-shaders/clamp.comp13
-rw-r--r--ggml/src/vulkan-shaders/copy.comp16
-rw-r--r--ggml/src/vulkan-shaders/dequant_f32.comp20
-rw-r--r--ggml/src/vulkan-shaders/dequant_funcs.comp68
-rw-r--r--ggml/src/vulkan-shaders/dequant_head.comp13
-rw-r--r--ggml/src/vulkan-shaders/dequant_iq4_nl.comp30
-rw-r--r--ggml/src/vulkan-shaders/dequant_q2_k.comp34
-rw-r--r--ggml/src/vulkan-shaders/dequant_q3_k.comp42
-rw-r--r--ggml/src/vulkan-shaders/dequant_q4_0.comp30
-rw-r--r--ggml/src/vulkan-shaders/dequant_q4_1.comp32
-rw-r--r--ggml/src/vulkan-shaders/dequant_q4_k.comp56
-rw-r--r--ggml/src/vulkan-shaders/dequant_q5_0.comp34
-rw-r--r--ggml/src/vulkan-shaders/dequant_q5_1.comp35
-rw-r--r--ggml/src/vulkan-shaders/dequant_q5_k.comp58
-rw-r--r--ggml/src/vulkan-shaders/dequant_q6_k.comp33
-rw-r--r--ggml/src/vulkan-shaders/dequant_q8_0.comp31
-rw-r--r--ggml/src/vulkan-shaders/diag_mask_inf.comp34
-rw-r--r--ggml/src/vulkan-shaders/div.comp12
-rw-r--r--ggml/src/vulkan-shaders/gelu.comp25
-rw-r--r--ggml/src/vulkan-shaders/generic_binary_head.comp48
-rw-r--r--ggml/src/vulkan-shaders/generic_head.comp9
-rw-r--r--ggml/src/vulkan-shaders/generic_unary_head.comp35
-rw-r--r--ggml/src/vulkan-shaders/get_rows.comp26
-rw-r--r--ggml/src/vulkan-shaders/get_rows_quant.comp31
-rw-r--r--ggml/src/vulkan-shaders/mul.comp12
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp29
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec.comp50
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_base.comp81
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_nc.comp71
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_p021.comp73
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp73
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp66
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp115
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp111
-rw-r--r--ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp79
-rw-r--r--ggml/src/vulkan-shaders/mul_mm.comp507
-rw-r--r--ggml/src/vulkan-shaders/norm.comp44
-rw-r--r--ggml/src/vulkan-shaders/relu.comp21
-rw-r--r--ggml/src/vulkan-shaders/rms_norm.comp42
-rw-r--r--ggml/src/vulkan-shaders/rope_head.comp44
-rw-r--r--ggml/src/vulkan-shaders/rope_neox.comp37
-rw-r--r--ggml/src/vulkan-shaders/rope_norm.comp37
-rw-r--r--ggml/src/vulkan-shaders/scale.comp12
-rw-r--r--ggml/src/vulkan-shaders/silu.comp22
-rw-r--r--ggml/src/vulkan-shaders/soft_max.comp106
-rw-r--r--ggml/src/vulkan-shaders/square.comp13
-rw-r--r--ggml/src/vulkan-shaders/sum_rows.comp37
-rw-r--r--ggml/src/vulkan-shaders/types.comp200
-rw-r--r--ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp525
341 files changed, 127433 insertions, 0 deletions
diff --git a/ggml/.gitignore b/ggml/.gitignore
new file mode 100644
index 00000000..c82d8e69
--- /dev/null
+++ b/ggml/.gitignore
@@ -0,0 +1,2 @@
+src/ggml-vulkan-shaders.hpp
+src/ggml-vulkan-shaders.cpp
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
new file mode 100644
index 00000000..66505753
--- /dev/null
+++ b/ggml/CMakeLists.txt
@@ -0,0 +1,243 @@
+cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
+project("ggml" C CXX)
+include(CheckIncludeFileCXX)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
+ set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
+ set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
+endif()
+
+if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
+ set(GGML_STANDALONE ON)
+
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+ # configure project version
+ # TODO
+else()
+ set(GGML_STANDALONE OFF)
+endif()
+
+if (EMSCRIPTEN)
+ set(BUILD_SHARED_LIBS_DEFAULT OFF)
+
+ option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON)
+else()
+ if (MINGW)
+ set(BUILD_SHARED_LIBS_DEFAULT OFF)
+ else()
+ set(BUILD_SHARED_LIBS_DEFAULT ON)
+ endif()
+endif()
+
+option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
+
+#
+# option list
+#
+
+# TODO: mark all options as advanced when not GGML_STANDALONE
+
+if (APPLE)
+ set(GGML_METAL_DEFAULT ON)
+ set(GGML_BLAS_DEFAULT ON)
+ set(GGML_BLAS_VENDOR_DEFAULT "Apple")
+else()
+ set(GGML_METAL_DEFAULT OFF)
+ set(GGML_BLAS_DEFAULT OFF)
+ set(GGML_BLAS_VENDOR_DEFAULT "Generic")
+endif()
+
+# general
+option(GGML_STATIC "ggml: static link libraries" OFF)
+option(GGML_NATIVE "ggml: enable -march=native flag" ON)
+option(GGML_LTO "ggml: enable link time optimization" OFF)
+option(GGML_CCACHE "ggml: use ccache if available" ON)
+
+# debug
+option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
+option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
+option(GGML_GPROF "ggml: enable gprof" OFF)
+
+# build
+option(GGML_FATAL_WARNINGS "ggml: enable -Werror flag" OFF)
+
+# sanitizers
+option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF)
+option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF)
+option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)
+
+# instruction set specific
+if (GGML_NATIVE)
+ set(INS_ENB OFF)
+else()
+ set(INS_ENB ON)
+endif()
+
+option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
+
+option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
+option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
+option(GGML_AVX512 "ggml: enable AVX512" OFF)
+option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
+option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
+option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
+option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
+if (NOT MSVC)
+ option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
+endif()
+option(GGML_LASX "ggml: enable lasx" ON)
+option(GGML_LSX "ggml: enable lsx" ON)
+option(GGML_SVE "ggml: enable SVE" OFF)
+
+if (WIN32)
+ set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")
+endif()
+
+# ggml core
+set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
+
+# 3rd party libs / backends
+option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
+option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
+set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
+ "ggml: BLAS library vendor")
+option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
+option(GGML_IQK_MUL_MAT "ggml: use optimized iqk matrix multiplications" ON)
+
+option(GGML_CUDA "ggml: use CUDA" OFF)
+option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
+option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
+option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
+set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
+set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
+option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
+set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
+ "ggml: iters./thread per block for Q2_K/Q6_K")
+set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
+ "ggml: max. batch size for using peer access")
+option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
+option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
+option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
+option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
+
+option(GGML_CURL "ggml: use libcurl to download model from an URL" OFF)
+option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
+option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
+option(GGML_VULKAN "ggml: use Vulkan" OFF)
+option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
+option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
+option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
+option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
+option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
+option(GGML_KOMPUTE "ggml: use Kompute" OFF)
+option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
+option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
+option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
+option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
+set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
+ "ggml: metal minimum macOS version")
+set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)")
+option(GGML_OPENMP "ggml: use OpenMP" ON)
+option(GGML_RPC "ggml: use RPC" OFF)
+option(GGML_SYCL "ggml: use SYCL" OFF)
+option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
+set (GGML_SYCL_TARGET "INTEL" CACHE STRING
+ "ggml: sycl target device")
+
+# extra artifacts
+option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
+option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
+
+#
+# dependencies
+#
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+
+set(THREADS_PREFER_PTHREAD_FLAG ON)
+
+find_package(Threads REQUIRED)
+
+#
+# build the library
+#
+
+add_subdirectory(src)
+
+#
+# tests and examples
+#
+
+if (GGML_BUILD_TESTS)
+ enable_testing()
+ add_subdirectory(tests)
+endif ()
+
+if (GGML_BUILD_EXAMPLES)
+ add_subdirectory(examples)
+endif ()
+
+#
+# install
+#
+
+include(GNUInstallDirs)
+include(CMakePackageConfigHelpers)
+
+# all public headers
+set(GGML_PUBLIC_HEADERS
+ include/ggml.h
+ include/ggml-alloc.h
+ include/ggml-backend.h
+ include/ggml-blas.h
+ include/ggml-cuda.h
+ include/ggml.h
+ include/ggml-kompute.h
+ include/ggml-metal.h
+ include/ggml-rpc.h
+ include/ggml-sycl.h
+ include/ggml-vulkan.h)
+
+set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
+#if (GGML_METAL)
+# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
+#endif()
+install(TARGETS ggml PUBLIC_HEADER)
+
+if (BUILD_SHARED_LIBS)
+ install(TARGETS ggml LIBRARY)
+endif()
+
+if (GGML_METAL)
+ install(
+ FILES src/ggml-metal.metal
+ PERMISSIONS
+ OWNER_READ
+ OWNER_WRITE
+ GROUP_READ
+ WORLD_READ
+ DESTINATION ${CMAKE_INSTALL_BINDIR})
+
+ if (NOT GGML_METAL_EMBED_LIBRARY)
+ install(
+ FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+ DESTINATION ${CMAKE_INSTALL_BINDIR}
+ )
+ endif()
+endif()
+
+if (GGML_STANDALONE)
+ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
+ ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+ @ONLY)
+
+ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
+ DESTINATION share/pkgconfig)
+endif()
diff --git a/ggml/cmake/FindSIMD.cmake b/ggml/cmake/FindSIMD.cmake
new file mode 100644
index 00000000..5533668e
--- /dev/null
+++ b/ggml/cmake/FindSIMD.cmake
@@ -0,0 +1,100 @@
+include(CheckCSourceRuns)
+
+set(AVX_CODE "
+ #include <immintrin.h>
+ int main()
+ {
+ __m256 a;
+ a = _mm256_set1_ps(0);
+ return 0;
+ }
+")
+
+set(AVX512_CODE "
+ #include <immintrin.h>
+ int main()
+ {
+ __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0);
+ __m512i b = a;
+ __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
+ return 0;
+ }
+")
+
+set(AVX2_CODE "
+ #include <immintrin.h>
+ int main()
+ {
+ __m256i a = {0};
+ a = _mm256_abs_epi16(a);
+ __m256i x;
+ _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
+ return 0;
+ }
+")
+
+set(FMA_CODE "
+ #include <immintrin.h>
+ int main()
+ {
+ __m256 acc = _mm256_setzero_ps();
+ const __m256 d = _mm256_setzero_ps();
+ const __m256 p = _mm256_setzero_ps();
+ acc = _mm256_fmadd_ps( d, p, acc );
+ return 0;
+ }
+")
+
+macro(check_sse type flags)
+ set(__FLAG_I 1)
+ set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+ foreach (__FLAG ${flags})
+ if (NOT ${type}_FOUND)
+ set(CMAKE_REQUIRED_FLAGS ${__FLAG})
+ check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
+ if (HAS_${type}_${__FLAG_I})
+ set(${type}_FOUND TRUE CACHE BOOL "${type} support")
+ set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
+ endif()
+ math(EXPR __FLAG_I "${__FLAG_I}+1")
+ endif()
+ endforeach()
+ set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+
+ if (NOT ${type}_FOUND)
+ set(${type}_FOUND FALSE CACHE BOOL "${type} support")
+ set(${type}_FLAGS "" CACHE STRING "${type} flags")
+ endif()
+
+ mark_as_advanced(${type}_FOUND ${type}_FLAGS)
+endmacro()
+
+# flags are for MSVC only!
+check_sse("AVX" " ;/arch:AVX")
+if (NOT ${AVX_FOUND})
+ set(GGML_AVX OFF)
+else()
+ set(GGML_AVX ON)
+endif()
+
+check_sse("AVX2" " ;/arch:AVX2")
+check_sse("FMA" " ;/arch:AVX2")
+if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
+ set(GGML_AVX2 OFF)
+else()
+ set(GGML_AVX2 ON)
+endif()
+
+check_sse("AVX512" " ;/arch:AVX512")
+if (NOT ${AVX512_FOUND})
+ set(GGML_AVX512 OFF)
+else()
+ set(GGML_AVX512 ON)
+endif()
diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
new file mode 100644
index 00000000..434c13b3
--- /dev/null
+++ b/ggml/include/ggml-alloc.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct ggml_backend * ggml_backend_t;
+
+// Tensor allocator
+struct ggml_tallocr {
+ ggml_backend_buffer_t buffer;
+ void * base;
+ size_t alignment;
+ size_t offset;
+};
+
+GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+
+// Graph allocator
+/*
+ Example usage:
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
+
+ // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+ ggml_gallocr_reserve(galloc, build_graph(max_batch));
+
+ // allocate the graph
+ struct ggml_cgraph * graph = build_graph(batch);
+ ggml_gallocr_alloc_graph(galloc, graph);
+
+ printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
+
+ // evaluate the graph
+ ggml_backend_graph_compute(backend, graph);
+*/
+
+// special tensor flags for use with the graph allocator:
+// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+// ggml_set_output(): output tensors are never freed and never overwritten
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
+
+// pre-allocate buffers from a measure graph - does not allocate or modify the graph
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(
+ ggml_gallocr_t galloc,
+ struct ggml_cgraph * graph,
+ const int * node_buffer_ids,
+ const int * leaf_buffer_ids);
+
+// automatic reallocation if the topology changes when using a single buffer
+// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
+
+#ifdef __cplusplus
+}
+#endif
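A minimal sketch (not part of this commit) of how the graph allocator declared above is typically driven: build a small graph in a no_alloc context, then let ggml_gallocr place it in a single buffer of the CPU buffer type from ggml-backend.h. The two-input add graph and the sizes are illustrative only.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        // metadata-only context: tensor data will live in the allocator's buffer
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
        ggml_set_input(a);   // inputs get non-overlapping addresses (see note above)
        ggml_set_input(b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ggml_add(ctx, a, b));

        // allocate the whole graph in one CPU buffer
        ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
        ggml_gallocr_alloc_graph(galloc, gf);
        printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));

        ggml_gallocr_free(galloc);
        ggml_free(ctx);
        return 0;
    }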
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
new file mode 100644
index 00000000..5f3f1e28
--- /dev/null
+++ b/ggml/include/ggml-backend.h
@@ -0,0 +1,238 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+ typedef struct ggml_backend_event * ggml_backend_event_t;
+ typedef struct ggml_backend * ggml_backend_t;
+ typedef void * ggml_backend_graph_plan_t;
+
+ //
+ // Backend buffer
+ //
+
+ // buffer type
+ GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
+ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
+ GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+ GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
+ GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+ GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
+
+ // buffer
+ enum ggml_backend_buffer_usage {
+ GGML_BACKEND_BUFFER_USAGE_ANY = 0,
+ GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+ GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
+ };
+
+ GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
+ GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
+ GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
+ GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+ GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer);
+ GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
+
+ //
+ // Backend
+ //
+
+ GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
+ GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+ GGML_API void ggml_backend_free(ggml_backend_t backend);
+
+ GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+ GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+ GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+ GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
+
+ GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+ GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+ GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+ GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+ GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+ GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+ GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
+
+ // tensor copy between different backends
+ GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ // asynchronous copy
+ // the copy is performed after all the currently queued operations in backend_src
+ // backend_dst will wait for the copy to complete before performing other operations
+ // automatic fallback to sync copy if async is not supported
+ GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ // events
+ GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend);
+ GGML_API void ggml_backend_event_free (ggml_backend_event_t event);
+ GGML_API void ggml_backend_event_record (ggml_backend_event_t event);
+ GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
+ GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event);
+
+ //
+ // CPU backend
+ //
+
+ GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+
+ GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
+ GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
+ GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+ // Create a backend buffer from an existing pointer
+ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+
+#ifdef GGML_USE_CPU_HBM
+ GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
+#endif
+
+ //
+ // Backend registry
+ //
+
+ // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
+
+ GGML_API size_t ggml_backend_reg_get_count(void);
+ GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
+ GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
+ GGML_API const char * ggml_backend_reg_get_name(size_t i);
+ GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
+ GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
+ GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
+
+ //
+ // Backend scheduler
+ //
+
+ // The backend scheduler allows for multiple backends to be used together
+ // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+ // The backends are selected based on:
+ // - the backend that supports the operation
+ // - the location of the pre-allocated tensors (e.g. the weights)
+ /*
+ Example usage:
+
+ // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
+        // preferably to run on the same backend as the buffer
+ ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
+ sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+
+ // initialize buffers from a max size graph (optional)
+ reserve_graph = build_graph(sched, max_batch_size);
+
+ // manually assign nodes to a backend (optional, should not be needed in most cases)
+ struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+ ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
+
+ ggml_backend_sched_reserve(sched, reserve_graph);
+
+ // compute
+ graph = build_graph(sched);
+ ggml_backend_sched_graph_compute(sched, graph);
+
+ // if there are graph inputs:
+ ggml_backend_sched_reset(sched);
+ ggml_backend_sched_alloc_graph(sched, graph);
+ ggml_backend_tensor_set(input_tensor, ...);
+ ggml_backend_sched_graph_compute(sched, graph);
+ }
+ */
+
+ struct ggml_backend_sched;
+ typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+ // when ask == true, the scheduler wants to know if the user wants to observe this node
+ // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+ //
+ // when ask == false, the scheduler is passing the node tensor to the user for observation
+ // if the user returns false, the scheduler will cancel the graph compute
+ //
+ typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
+
+ // Initialize a backend scheduler
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+ GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+ // Initialize backend buffers from a measure graph
+ GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+ GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+ GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
+ // Get the number of splits of the last graph
+ GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+ GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
+
+ GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+
+ GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+ GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+
+ // Allocate and compute graph on the backend scheduler
+ GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+ GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+ GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+ GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
+
+ // Reset all assignments and allocators - must be called before changing the node backends
+ GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
+
+ // Set a callback to be called for each resulting node during graph compute
+ GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
+ //
+ // Utils
+ //
+
+ struct ggml_backend_graph_copy {
+ ggml_backend_buffer_t buffer;
+ struct ggml_context * ctx_allocated;
+ struct ggml_context * ctx_unallocated;
+ struct ggml_cgraph * graph;
+ };
+
+ // Copy a graph to a different backend
+ GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+ GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+ typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+ // Compare the output of two backends
+ GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+ // Tensor initialization
+ GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+ GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
+
+
+#ifdef __cplusplus
+}
+#endif
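A small end-to-end sketch (not part of this commit) of the CPU backend API declared above: every tensor of a no_alloc context is placed in one backend buffer, the inputs are uploaded with ggml_backend_tensor_set, and the graph is run with ggml_backend_graph_compute. Shapes, values, and the thread count are illustrative.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // data lives in the backend buffer, not the context
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        struct ggml_tensor * c = ggml_add(ctx, a, b);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_backend_cpu_set_n_threads(backend, 4);

        // allocate every tensor of the context (inputs and graph nodes) in one buffer
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        const float av[4] = {1, 2, 3, 4}, bv[4] = {10, 20, 30, 40};
        ggml_backend_tensor_set(a, av, 0, sizeof(av));
        ggml_backend_tensor_set(b, bv, 0, sizeof(bv));

        ggml_backend_graph_compute(backend, gf);

        float cv[4];
        ggml_backend_tensor_get(c, cv, 0, sizeof(cv));
        printf("c[0] = %f\n", cv[0]);

        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }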
diff --git a/ggml/include/ggml-blas.h b/ggml/include/ggml-blas.h
new file mode 100644
index 00000000..f2e37de0
--- /dev/null
+++ b/ggml/include/ggml-blas.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void);
+
+GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend);
+
+// number of threads used for conversion to float
+// for openblas and blis, this will also set the number of threads used for blas operations
+GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+
+
+#ifdef __cplusplus
+}
+#endif
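A sketch (not part of this commit) of the BLAS backend entry points above. Whether a BLAS library is actually linked depends on GGML_BLAS at build time, and the thread count here is illustrative.

    #include "ggml-blas.h"
    #include <stdio.h>

    int main(void) {
        ggml_backend_t backend = ggml_backend_blas_init();
        if (backend == NULL || !ggml_backend_is_blas(backend)) {
            fprintf(stderr, "BLAS backend not available\n");
            return 1;
        }
        // threads used for conversion to float (and, for OpenBLAS/BLIS, the BLAS ops themselves)
        ggml_backend_blas_set_n_threads(backend, 8);
        printf("backend: %s\n", ggml_backend_name(backend));
        ggml_backend_free(backend);
        return 0;
    }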
diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h
new file mode 100644
index 00000000..ca73211f
--- /dev/null
+++ b/ggml/include/ggml-cann.h
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml-backend.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief Maximum number of CANN devices supported.
+ */
+#define GGML_CANN_MAX_DEVICES 16
+
+/**
+ * @brief Initializes the CANN backend for a specified device.
+ *
+ * This function initializes the CANN backend for the given device.
+ * It verifies the device index, allocates a context, and creates a backend
+ * instance.
+ *
+ * @param device The index of the device to initialize.
+ * @return A pointer to the initialized backend instance, or nullptr on failure.
+ */
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device);
+
+/**
+ * @brief Checks if a given backend is a CANN backend.
+ *
+ * This function verifies if the provided backend is a CANN backend by comparing
+ * its GUID with the CANN backend's GUID.
+ *
+ * @param backend The backend instance to check.
+ * @return True if the backend is a CANN backend, false otherwise.
+ */
+GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend);
+
+/**
+ * @brief Retrieves the CANN buffer type for a specified device.
+ *
+ * This function initializes and returns the buffer type interface associated
+ * with the given device. It ensures thread-safe access using a mutex.
+ *
+ * @param device The device index for which to retrieve the buffer type.
+ * @return A pointer to the buffer type interface for the specified device, or
+ * nullptr if the device index is out of range.
+ */
+GGML_API GGML_CALL ggml_backend_buffer_type_t
+ggml_backend_cann_buffer_type(int32_t device);
+
+/**
+ * @brief Retrieves the number of CANN devices available.
+ *
+ * This function returns the number of CANN devices available based on
+ * information obtained from `ggml_cann_info()`.
+ *
+ * @return The number of CANN devices available.
+ */
+GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void);
+
+/**
+ * @brief Retrieves the description of a specific CANN device.
+ *
+ * This function sets the specified device, retrieves the SoC name,
+ * and writes it into the provided description buffer.
+ *
+ * @param device The device index to retrieve the description for.
+ * @param description Pointer to a buffer where the description will be written.
+ * @param description_size Size of the description buffer.
+ */
+GGML_API GGML_CALL void ggml_backend_cann_get_device_description(
+ int32_t device, char* description, size_t description_size);
+
+/**
+ * @brief Retrieves the memory information of a specific CANN device.
+ *
+ * This function sets the specified device, retrieves the free and total
+ * memory information of the specified type (ACL_HBM_MEM), and stores them
+ * in the provided pointers.
+ *
+ * @param device The device index to retrieve memory information for.
+ * @param free Pointer to a variable where the free memory size will be stored.
+ * @param total Pointer to a variable where the total memory size will be
+ * stored.
+ */
+GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
+ size_t* free,
+ size_t* total);
+
+/**
+ * @brief Set the logging callback for GGML.
+ *
+ * This function sets the logging callback and user data for logging.
+ *
+ * @param log_callback The logging callback to set.
+ * @param user_data User data to pass to the logging callback.
+ */
+GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
+ void* user_data);
+
+#ifdef __cplusplus
+}
+#endif
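A sketch (not part of this commit) exercising the CANN query functions documented above. It assumes only the declarations in this header and prints whatever devices the CANN runtime reports.

    #include "ggml-cann.h"
    #include <stdio.h>

    int main(void) {
        int32_t n = ggml_backend_cann_get_device_count();
        printf("CANN devices: %d\n", (int) n);
        for (int32_t i = 0; i < n; ++i) {
            char   desc[128];
            size_t free_mem, total_mem;
            ggml_backend_cann_get_device_description(i, desc, sizeof(desc));
            ggml_backend_cann_get_device_memory(i, &free_mem, &total_mem);
            printf("  device %d: %s, %zu/%zu MiB free\n", (int) i, desc,
                   free_mem/(1024*1024), total_mem/(1024*1024));
        }
        if (n > 0) {
            ggml_backend_t backend = ggml_backend_cann_init(0);
            // ... build and compute graphs as with any other backend ...
            ggml_backend_free(backend);
        }
        return 0;
    }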
diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h
new file mode 100644
index 00000000..d7903c66
--- /dev/null
+++ b/ggml/include/ggml-cuda.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_CUDA_MAX_DEVICES 16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
+
+// device buffer
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+
+GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data);
+#ifdef __cplusplus
+}
+#endif
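A sketch (not part of this commit) of the CUDA device queries and the row-split buffer type declared above. The 60/40 split over the first two devices is an arbitrary illustration of the per-device proportions this buffer type accepts.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-cuda.h"
    #include <stdio.h>

    int main(void) {
        int n_dev = ggml_backend_cuda_get_device_count();
        printf("%s devices: %d\n", GGML_CUDA_NAME, n_dev);
        if (n_dev < 1) {
            return 1;
        }

        // primary backend on device 0
        ggml_backend_t backend = ggml_backend_cuda_init(0);

        // buffer type that splits matrices by rows across devices (illustrative proportions)
        float tensor_split[GGML_CUDA_MAX_DEVICES] = {0.6f, 0.4f};
        ggml_backend_buffer_type_t split_buft = ggml_backend_cuda_split_buffer_type(tensor_split);

        // weights created in a no_alloc context would be placed in such a buffer via
        // ggml_backend_alloc_ctx_tensors_from_buft(ctx, split_buft);
        (void) split_buft;

        ggml_backend_free(backend);
        return 0;
    }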
diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h
new file mode 100644
index 00000000..17146545
--- /dev/null
+++ b/ggml/include/ggml-kompute.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_vk_device {
+ int index;
+ int type; // same as VkPhysicalDeviceType
+ size_t heapSize;
+ const char * name;
+ const char * vendor;
+ int subgroupSize;
+ uint64_t bufferAlignment;
+ uint64_t maxAlloc;
+};
+
+struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
+bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
+bool ggml_vk_has_vulkan(void);
+bool ggml_vk_has_device(void);
+struct ggml_vk_device ggml_vk_current_device(void);
+
+//
+// backend API
+//
+
+// forward declaration
+typedef struct ggml_backend * ggml_backend_t;
+
+GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
+
+GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
+
+GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
+
+#ifdef __cplusplus
+}
+#endif
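A sketch (not part of this commit) that lists the Vulkan devices visible to the Kompute backend using the helpers above and initializes the first one. Passing 0 as memoryRequired is taken here to mean no minimum-memory filter; ownership of the returned array is left to the backend.

    #include "ggml-kompute.h"
    #include <stdio.h>

    int main(void) {
        size_t count = 0;
        struct ggml_vk_device * devices = ggml_vk_available_devices(0, &count);
        printf("Vulkan devices: %zu\n", count);
        for (size_t i = 0; i < count; ++i) {
            printf("  [%d] %s (%s), heap %zu MiB, subgroup %d\n",
                   devices[i].index, devices[i].name, devices[i].vendor,
                   devices[i].heapSize/(1024*1024), devices[i].subgroupSize);
        }
        if (count > 0) {
            ggml_backend_t backend = ggml_backend_kompute_init(devices[0].index);
            ggml_backend_free(backend);
        }
        return 0;
    }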
diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h
new file mode 100644
index 00000000..6c3226c3
--- /dev/null
+++ b/ggml/include/ggml-metal.h
@@ -0,0 +1,65 @@
+// An interface for computing a ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
+//
+// How does it work?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stddef.h>
+#include <stdbool.h>
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 64
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend API
+// user-code should use only these functions
+//
+
+GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);
+
+GGML_API ggml_backend_t ggml_backend_metal_init(void);
+
+GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
+
+GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+
+GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
+
+// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
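+// a minimal sketch of the backend API above; the family id passed to
+// ggml_backend_metal_supports_family() is illustrative only:
+//
+//   ggml_backend_t backend = ggml_backend_metal_init();
+//   if (backend == NULL || !ggml_backend_is_metal(backend)) {
+//       // Metal is unavailable - fall back to the CPU backend
+//   } else if (!ggml_backend_metal_supports_family(backend, 7)) {
+//       // the device does not support the required feature set
+//   }
+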
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
new file mode 100644
index 00000000..aa144832
--- /dev/null
+++ b/ggml/include/ggml-rpc.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_RPC_MAX_SERVERS 16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+
+GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
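+// a minimal client-side sketch; the endpoint string is illustrative and its exact
+// format ("host:port" is assumed here) is defined by the RPC server:
+//
+//   ggml_backend_t rpc = ggml_backend_rpc_init("192.168.1.10:50052");
+//   if (rpc != NULL && ggml_backend_is_rpc(rpc)) {
+//       size_t free_mem = 0, total_mem = 0;
+//       ggml_backend_rpc_get_device_memory("192.168.1.10:50052", &free_mem, &total_mem);
+//   }
+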
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml-sycl.h b/ggml/include/ggml-sycl.h
new file mode 100644
index 00000000..43ab1519
--- /dev/null
+++ b/ggml/include/ggml-sycl.h
@@ -0,0 +1,42 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#define GGML_SYCL_NAME "SYCL"
+#define GGML_SYCL_MAX_DEVICES 48
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
+
+// device buffer
+GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
+
+GGML_API void ggml_backend_sycl_print_sycl_devices(void);
+GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
+GGML_API GGML_CALL int ggml_backend_sycl_get_device_count();
+GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
+
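+// a minimal sketch, assuming at least one SYCL GPU is present (otherwise the
+// entries of the id list are not meaningful):
+//
+//   int ids[GGML_SYCL_MAX_DEVICES];
+//   ggml_sycl_get_gpu_list(ids, GGML_SYCL_MAX_DEVICES);
+//   ggml_backend_t backend = ggml_backend_sycl_init(ids[0]);
+//   ggml_backend_buffer_type_t buft = ggml_backend_sycl_buffer_type(ids[0]);
+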
+// SYCL doesn't support registering host memory, keep here for reference
+// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
+// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer);
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h
new file mode 100644
index 00000000..af661c2d
--- /dev/null
+++ b/ggml/include/ggml-vulkan.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_VK_NAME "Vulkan"
+#define GGML_VK_MAX_DEVICES 16
+
+GGML_API void ggml_vk_instance_init(void);
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
+
+GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
+GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void);
+GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
+GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
new file mode 100644
index 00000000..5ed8f73d
--- /dev/null
+++ b/ggml/include/ggml.h
@@ -0,0 +1,2453 @@
+#pragma once
+
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you wish some specific topics to be covered, feel free to drop a comment:
+//
+// https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+// - a set of tensor operations
+// - automatic differentiation
+// - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+// - linear regression
+// - support vector machines
+// - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+// {
+// struct ggml_init_params params = {
+// .mem_size = 16*1024*1024,
+// .mem_buffer = NULL,
+// };
+//
+// // memory allocation happens here
+// struct ggml_context * ctx = ggml_init(params);
+//
+// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+// ggml_set_param(ctx, x); // x is an input variable
+//
+// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+// struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+// ...
+// }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+// {
+// ...
+//
+// struct ggml_cgraph * gf = ggml_new_graph(ctx);
+// ggml_build_forward_expand(gf, f);
+//
+// // set the input variable and parameter values
+// ggml_set_f32(x, 2.0f);
+// ggml_set_f32(a, 3.0f);
+// ggml_set_f32(b, 4.0f);
+//
+//       ggml_graph_compute_with_ctx(ctx, gf, n_threads);
+//
+// printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+// ...
+// }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach allows the function graph to be defined once and its forward or backward graphs to be computed
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+// - ggml_permute()
+// - ggml_conv_1d_1s()
+// - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+// What is Automatic Differentiation?
+// https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+// {
+// struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+// assert(c->src[0] == a);
+// assert(c->src[1] == b);
+// }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This makes it possible
+// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+// {
+// const int nx = 2;
+// const int ny = 3;
+//
+// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
+//
+// for (int y = 0; y < ny; y++) {
+// for (int x = 0; x < nx; x++) {
+// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
+// }
+// }
+//
+// ...
+// }
+//
+// Alternatively, helper functions such as ggml_get_f32_1d() and ggml_set_f32_1d() can be used.
+//
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
+#ifdef GGML_SHARED
+# if defined(_WIN32) && !defined(__MINGW32__)
+# ifdef GGML_BUILD
+# define GGML_API __declspec(dllexport)
+# else
+# define GGML_API __declspec(dllimport)
+# endif
+# else
+# define GGML_API __attribute__ ((visibility ("default")))
+# endif
+#else
+# define GGML_API
+#endif
+
+#ifdef GGML_MULTIPLATFORM
+# if defined(_WIN32)
+# define GGML_CALL
+# else
+# define GGML_CALL __attribute__((__ms_abi__))
+# endif
+#else
+# define GGML_CALL
+#endif
+
+// TODO: support for clang
+#ifdef __GNUC__
+# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+# define GGML_DEPRECATED(func, hint) func
+#endif
+
+#ifndef __GNUC__
+# define GGML_ATTRIBUTE_FORMAT(...)
+#elif defined(__MINGW32__)
+# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#define GGML_FILE_MAGIC 0x67676d6c // "ggml"
+#define GGML_FILE_VERSION 1
+
+#define GGML_QNT_VERSION 2 // bump this on quantization format changes
+#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
+
+#define GGML_MAX_DIMS 4
+#define GGML_MAX_PARAMS 2048
+#define GGML_MAX_CONTEXTS 64
+#define GGML_MAX_SRC 10
+#ifndef GGML_MAX_NAME
+#define GGML_MAX_NAME 64
+#endif
+#define GGML_MAX_OP_PARAMS 64
+#define GGML_DEFAULT_N_THREADS 4
+#define GGML_DEFAULT_GRAPH_SIZE 2048
+#if UINTPTR_MAX == 0xFFFFFFFF
+ #define GGML_MEM_ALIGN 4
+#else
+ #define GGML_MEM_ALIGN 16
+#endif
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
+#define GGUF_MAGIC "GGUF"
+
+#define GGUF_VERSION 3
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#define GGML_UNUSED(x) (void)(x)
+
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
+
+#define GGML_ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ fflush(stdout); \
+ fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+ ggml_print_backtrace(); \
+ abort(); \
+ } \
+ } while (0)
+
+#ifndef NDEBUG
+#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
+#elif defined(__GNUC__)
+#define GGML_UNREACHABLE() __builtin_unreachable()
+#elif defined(_MSC_VER)
+#define GGML_UNREACHABLE() __assume(0)
+#else
+#define GGML_UNREACHABLE() ((void) 0)
+#endif
+
+// used to copy the number of elements and stride in bytes of tensors into local variables.
+// main purpose is to reduce code duplication and improve readability.
+//
+// example:
+//
+// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
+//
+#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
+ const type prefix##0 = (pointer)->array[0]; \
+ GGML_UNUSED(prefix##0);
+#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
+ GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
+ const type prefix##1 = (pointer)->array[1]; \
+ GGML_UNUSED(prefix##1);
+#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
+ GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
+ const type prefix##2 = (pointer)->array[2]; \
+ GGML_UNUSED(prefix##2);
+#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
+ GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
+ const type prefix##3 = (pointer)->array[3]; \
+ GGML_UNUSED(prefix##3);
+
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS01 \
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ enum ggml_status {
+ GGML_STATUS_ALLOC_FAILED = -2,
+ GGML_STATUS_FAILED = -1,
+ GGML_STATUS_SUCCESS = 0,
+ GGML_STATUS_ABORTED = 1,
+ };
+
+ // get ggml_status name string
+ GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status);
+
+ // ieee 754-2008 half-precision float16
+ // todo: make this not an integral type
+ typedef uint16_t ggml_fp16_t;
+ GGML_API float ggml_fp16_to_fp32(ggml_fp16_t);
+ GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
+ GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
+ GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
+
+ // google brain half-precision bfloat16
+ typedef struct { uint16_t bits; } ggml_bf16_t;
+ GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
+ GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
+ GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
+ GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
+
+ struct ggml_object;
+ struct ggml_context;
+
+ // NOTE: always add types at the end of the enum to keep backward compatibility
+ enum ggml_type {
+ GGML_TYPE_F32 = 0,
+ GGML_TYPE_F16 = 1,
+ GGML_TYPE_Q4_0 = 2,
+ GGML_TYPE_Q4_1 = 3,
+ // GGML_TYPE_Q4_2 = 4, support has been removed
+ // GGML_TYPE_Q4_3 = 5, support has been removed
+ GGML_TYPE_Q5_0 = 6,
+ GGML_TYPE_Q5_1 = 7,
+ GGML_TYPE_Q8_0 = 8,
+ GGML_TYPE_Q8_1 = 9,
+ GGML_TYPE_Q2_K = 10,
+ GGML_TYPE_Q3_K = 11,
+ GGML_TYPE_Q4_K = 12,
+ GGML_TYPE_Q5_K = 13,
+ GGML_TYPE_Q6_K = 14,
+ GGML_TYPE_Q8_K = 15,
+ GGML_TYPE_IQ2_XXS = 16,
+ GGML_TYPE_IQ2_XS = 17,
+ GGML_TYPE_IQ3_XXS = 18,
+ GGML_TYPE_IQ1_S = 19,
+ GGML_TYPE_IQ4_NL = 20,
+ GGML_TYPE_IQ3_S = 21,
+ GGML_TYPE_IQ2_S = 22,
+ GGML_TYPE_IQ4_XS = 23,
+ GGML_TYPE_I8 = 24,
+ GGML_TYPE_I16 = 25,
+ GGML_TYPE_I32 = 26,
+ GGML_TYPE_I64 = 27,
+ GGML_TYPE_F64 = 28,
+ GGML_TYPE_IQ1_M = 29,
+ GGML_TYPE_BF16 = 30,
+ GGML_TYPE_Q4_0_4_4 = 31,
+ GGML_TYPE_Q4_0_4_8 = 32,
+ GGML_TYPE_Q4_0_8_8 = 33,
+ GGML_TYPE_IQ1_BN = 34,
+ GGML_TYPE_IQ2_BN = 35,
+ GGML_TYPE_Q8_K64 = 36,
+ GGML_TYPE_COUNT,
+ };
+
+ // precision
+ enum ggml_prec {
+ GGML_PREC_DEFAULT,
+ GGML_PREC_F32,
+ };
+
+ enum ggml_backend_type {
+ GGML_BACKEND_TYPE_CPU = 0,
+ GGML_BACKEND_TYPE_GPU = 10,
+ GGML_BACKEND_TYPE_GPU_SPLIT = 20,
+ };
+
+ // model file types
+ enum ggml_ftype {
+ GGML_FTYPE_UNKNOWN = -1,
+ GGML_FTYPE_ALL_F32 = 0,
+ GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
+ GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
+ GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+ GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ1_BN = 28, // except 1d tensors
+ GGML_FTYPE_MOSTLY_IQ2_BN = 29, // except 1d tensors
+ };
+
+ // available tensor operations:
+ enum ggml_op {
+ GGML_OP_NONE = 0,
+
+ GGML_OP_DUP,
+ GGML_OP_ADD,
+ GGML_OP_ADD1,
+ GGML_OP_ACC,
+ GGML_OP_SUB,
+ GGML_OP_MUL,
+ GGML_OP_DIV,
+ GGML_OP_SQR,
+ GGML_OP_SQRT,
+ GGML_OP_LOG,
+ GGML_OP_SUM,
+ GGML_OP_SUM_ROWS,
+ GGML_OP_MEAN,
+ GGML_OP_ARGMAX,
+ GGML_OP_REPEAT,
+ GGML_OP_REPEAT_BACK,
+ GGML_OP_CONCAT,
+ GGML_OP_SILU_BACK,
+ GGML_OP_NORM, // normalize
+ GGML_OP_RMS_NORM,
+ GGML_OP_RMS_NORM_BACK,
+ GGML_OP_GROUP_NORM,
+
+ GGML_OP_MUL_MAT,
+ GGML_OP_MUL_MAT_ID,
+ GGML_OP_OUT_PROD,
+
+ GGML_OP_SCALE,
+ GGML_OP_SET,
+ GGML_OP_CPY,
+ GGML_OP_CONT,
+ GGML_OP_RESHAPE,
+ GGML_OP_VIEW,
+ GGML_OP_PERMUTE,
+ GGML_OP_TRANSPOSE,
+ GGML_OP_GET_ROWS,
+ GGML_OP_GET_ROWS_BACK,
+ GGML_OP_DIAG,
+ GGML_OP_DIAG_MASK_INF,
+ GGML_OP_DIAG_MASK_ZERO,
+ GGML_OP_SOFT_MAX,
+ GGML_OP_SOFT_MAX_BACK,
+ GGML_OP_ROPE,
+ GGML_OP_ROPE_BACK,
+ GGML_OP_CLAMP,
+ GGML_OP_CONV_TRANSPOSE_1D,
+ GGML_OP_IM2COL,
+ GGML_OP_CONV_TRANSPOSE_2D,
+ GGML_OP_POOL_1D,
+ GGML_OP_POOL_2D,
+ GGML_OP_UPSCALE, // nearest interpolate
+ GGML_OP_PAD,
+ GGML_OP_ARANGE,
+ GGML_OP_TIMESTEP_EMBEDDING,
+ GGML_OP_ARGSORT,
+ GGML_OP_LEAKY_RELU,
+
+ GGML_OP_FLASH_ATTN_EXT,
+ GGML_OP_FLASH_ATTN_BACK,
+ GGML_OP_SSM_CONV,
+ GGML_OP_SSM_SCAN,
+ GGML_OP_WIN_PART,
+ GGML_OP_WIN_UNPART,
+ GGML_OP_GET_REL_POS,
+ GGML_OP_ADD_REL_POS,
+
+ GGML_OP_UNARY,
+
+ GGML_OP_MAP_UNARY,
+ GGML_OP_MAP_BINARY,
+
+ GGML_OP_MAP_CUSTOM1_F32,
+ GGML_OP_MAP_CUSTOM2_F32,
+ GGML_OP_MAP_CUSTOM3_F32,
+
+ GGML_OP_MAP_CUSTOM1,
+ GGML_OP_MAP_CUSTOM2,
+ GGML_OP_MAP_CUSTOM3,
+
+ GGML_OP_CROSS_ENTROPY_LOSS,
+ GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
+ GGML_OP_COUNT,
+ };
+
+ enum ggml_unary_op {
+ GGML_UNARY_OP_ABS,
+ GGML_UNARY_OP_SGN,
+ GGML_UNARY_OP_NEG,
+ GGML_UNARY_OP_STEP,
+ GGML_UNARY_OP_TANH,
+ GGML_UNARY_OP_ELU,
+ GGML_UNARY_OP_RELU,
+ GGML_UNARY_OP_SIGMOID,
+ GGML_UNARY_OP_GELU,
+ GGML_UNARY_OP_GELU_QUICK,
+ GGML_UNARY_OP_SILU,
+ GGML_UNARY_OP_HARDSWISH,
+ GGML_UNARY_OP_HARDSIGMOID,
+
+ GGML_UNARY_OP_COUNT,
+ };
+
+ enum ggml_object_type {
+ GGML_OBJECT_TYPE_TENSOR,
+ GGML_OBJECT_TYPE_GRAPH,
+ GGML_OBJECT_TYPE_WORK_BUFFER
+ };
+
+ enum ggml_log_level {
+ GGML_LOG_LEVEL_ERROR = 2,
+ GGML_LOG_LEVEL_WARN = 3,
+ GGML_LOG_LEVEL_INFO = 4,
+ GGML_LOG_LEVEL_DEBUG = 5
+ };
+
+ enum ggml_tensor_flag {
+ GGML_TENSOR_FLAG_INPUT = 1,
+ GGML_TENSOR_FLAG_OUTPUT = 2,
+ GGML_TENSOR_FLAG_PARAM = 4,
+ };
+
+ // ggml object
+ struct ggml_object {
+ size_t offs;
+ size_t size;
+
+ struct ggml_object * next;
+
+ enum ggml_object_type type;
+
+ char padding[4];
+ };
+
+ static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+ // n-dimensional tensor
+ struct ggml_tensor {
+ enum ggml_type type;
+
+ GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
+
+ struct ggml_backend_buffer * buffer;
+
+ int64_t ne[GGML_MAX_DIMS]; // number of elements
+ size_t nb[GGML_MAX_DIMS]; // stride in bytes:
+ // nb[0] = ggml_type_size(type)
+ // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding
+ // nb[i] = nb[i-1] * ne[i-1]
+
+ // compute data
+ enum ggml_op op;
+
+ // op params - allocated as int32_t for alignment
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
+ int32_t flags;
+
+ struct ggml_tensor * grad;
+ struct ggml_tensor * src[GGML_MAX_SRC];
+
+ // source tensor and offset for views
+ struct ggml_tensor * view_src;
+ size_t view_offs;
+
+ void * data;
+
+ char name[GGML_MAX_NAME];
+
+ void * extra; // extra things e.g. for ggml-cuda.cu
+
+ // char padding[4];
+ };
+
+ static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+
+ // Abort callback
+ // If not NULL, called before ggml computation
+ // If it returns true, the computation is aborted
+ typedef bool (*ggml_abort_callback)(void * data);
+
+ // the compute plan that needs to be prepared for ggml_graph_compute()
+ // since https://github.com/ggerganov/ggml/issues/287
+ struct ggml_cplan {
+ size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by the caller before calling `ggml_graph_compute()`
+
+ int n_threads;
+
+ // abort ggml_graph_compute when true
+ ggml_abort_callback abort_callback;
+ void * abort_callback_data;
+ };
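+
+    // a minimal sketch of the work-buffer contract above; "graph" is a placeholder for a
+    // built ggml_cgraph, and the sketch assumes the ggml_graph_plan() / ggml_graph_compute()
+    // declarations that appear further down in this header:
+    //
+    //   struct ggml_cplan plan = ggml_graph_plan(graph, /*n_threads =*/ 4);
+    //   if (plan.work_size > 0) {
+    //       plan.work_data = malloc(plan.work_size); // work buffer is owned by the caller
+    //   }
+    //   plan.abort_callback      = NULL; // or a callback that returns true to abort
+    //   plan.abort_callback_data = NULL;
+    //   enum ggml_status status = ggml_graph_compute(graph, &plan);
+    //   free(plan.work_data);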
+
+ enum ggml_cgraph_eval_order {
+ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+ GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+ GGML_CGRAPH_EVAL_ORDER_COUNT
+ };
+
+ struct ggml_hash_set {
+ size_t size;
+ struct ggml_tensor ** keys;
+ };
+
+ // computation graph
+ struct ggml_cgraph {
+ int size;
+ int n_nodes;
+ int n_leafs;
+
+ struct ggml_tensor ** nodes;
+ struct ggml_tensor ** grads;
+ struct ggml_tensor ** leafs;
+
+ struct ggml_hash_set visited_hash_table;
+
+ enum ggml_cgraph_eval_order order;
+ };
+
+ // scratch buffer
+ struct ggml_scratch {
+ size_t offs;
+ size_t size;
+ void * data;
+ };
+
+ struct ggml_init_params {
+ // memory pool
+ size_t mem_size; // bytes
+ void * mem_buffer; // if NULL, memory will be allocated internally
+ bool no_alloc; // don't allocate memory for the tensor data
+ };
+
+ // numa strategies
+ enum ggml_numa_strategy {
+ GGML_NUMA_STRATEGY_DISABLED = 0,
+ GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+ GGML_NUMA_STRATEGY_ISOLATE = 2,
+ GGML_NUMA_STRATEGY_NUMACTL = 3,
+ GGML_NUMA_STRATEGY_MIRROR = 4,
+ GGML_NUMA_STRATEGY_COUNT
+ };
+
+ //
+ // GUID
+ //
+
+ // GUID types
+ typedef uint8_t ggml_guid[16];
+ typedef ggml_guid * ggml_guid_t;
+
+ GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);
+
+ // misc
+
+ GGML_API void ggml_time_init(void); // call this once at the beginning of the program
+ GGML_API int64_t ggml_time_ms(void);
+ GGML_API int64_t ggml_time_us(void);
+ GGML_API int64_t ggml_cycles(void);
+ GGML_API int64_t ggml_cycles_per_ms(void);
+
+ GGML_API void ggml_print_backtrace(void);
+
+ // accepts a UTF-8 path, even on Windows
+ GGML_API FILE * ggml_fopen(const char * fname, const char * mode);
+
+ GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
+ GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
+ GGML_API void ggml_print_object (const struct ggml_object * obj);
+ GGML_API void ggml_print_objects(const struct ggml_context * ctx);
+
+ GGML_API GGML_CALL int64_t ggml_nelements (const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL int64_t ggml_nrows (const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor);
+ GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
+
+ GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
+ GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
+ GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+
+ GGML_DEPRECATED(
+ GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
+ "use ggml_row_size() instead");
+
+ GGML_API GGML_CALL const char * ggml_type_name(enum ggml_type type);
+ GGML_API GGML_CALL const char * ggml_op_name (enum ggml_op op);
+ GGML_API const char * ggml_op_symbol(enum ggml_op op);
+
+ GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+ GGML_API GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
+ GGML_API GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor);
+
+ GGML_API GGML_CALL bool ggml_is_quantized(enum ggml_type type);
+
+ // TODO: temporary until model loading of ggml examples is refactored
+ GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
+
+ GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
+ GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
+ GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
+
+ GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+ GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+ GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
+ GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+ GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
+ GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
+ // use this to compute the memory overhead of a tensor
+ GGML_API size_t ggml_tensor_overhead(void);
+
+ GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
+ // main
+
+ GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
+ GGML_API void ggml_free(struct ggml_context * ctx);
+
+ GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
+
+ GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+ GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
+ GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
+
+ GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
+ GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx);
+ GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx);
+
+ GGML_API struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int64_t *ne);
+
+ GGML_API struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0);
+
+ GGML_API struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1);
+
+ GGML_API struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ GGML_API struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
+ GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+ GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+ GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+ GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
+
+ // Context tensor enumeration and lookup
+ GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
+ GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
+ GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
+
+ GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+ GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+ GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+ // Converts a flat index into coordinates
+ GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
+ GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+ GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+ GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
+ GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+ GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+ GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+ GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
+ GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
+ GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+ GGML_API GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+ GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
+ GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
+ GGML_ATTRIBUTE_FORMAT(2, 3)
+ GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
+
+ //
+ // operations on tensors with backpropagation
+ //
+
+ GGML_API struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_dup_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_add_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type);
+
+ GGML_API struct ggml_tensor * ggml_add1(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_add1_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // dst = a
+ // view(dst, nb1, nb2, nb3, offset) += b
+ // return dst
+ GGML_API struct ggml_tensor * ggml_acc(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_acc_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_sub_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_mul_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_div_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sqr_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sqrt_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_log(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_log_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // return scalar
+ GGML_API struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+    // sums along rows: an input of shape [a,b,c,d] gives a result of shape [1,b,c,d]
+ GGML_API struct ggml_tensor * ggml_sum_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // mean along rows
+ GGML_API struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // argmax along rows
+ GGML_API struct ggml_tensor * ggml_argmax(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+    // if a is the same shape as b, and a is not a parameter, return a
+ // otherwise, return a new tensor: repeat(a) to fit in b
+ GGML_API struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // sums repetitions in a into shape of b
+ GGML_API struct ggml_tensor * ggml_repeat_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // concat a and b along dim
+ // used in stable-diffusion
+ GGML_API struct ggml_tensor * ggml_concat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int dim);
+
+ GGML_API struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_abs_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sgn_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_neg_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_step_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_tanh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_tanh_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_elu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_elu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_leaky_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, float negative_slope, bool inplace);
+
+ GGML_API struct ggml_tensor * ggml_relu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_gelu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_gelu_quick(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_silu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_silu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // a - x
+ // b - dy
+ GGML_API struct ggml_tensor * ggml_silu_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // hardswish(x) = x * relu6(x + 3) / 6
+ GGML_API struct ggml_tensor * ggml_hardswish(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // hardsigmoid(x) = relu6(x + 3) / 6
+ GGML_API struct ggml_tensor * ggml_hardsigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // normalize along rows
+ GGML_API struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps);
+
+ // group normalize along ne0*ne1*n_groups
+ // used in stable-diffusion
+ // TODO: eps is hardcoded to 1e-6 for now
+ GGML_API struct ggml_tensor * ggml_group_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_groups);
+
+ GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_groups);
+
+ // a - x
+ // b - dy
+ GGML_API struct ggml_tensor * ggml_rms_norm_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps);
+
+ // A: k columns, n rows => [ne03, ne02, n, k]
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
+ GGML_API struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // change the precision of a matrix multiplication
+ // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+ GGML_API void ggml_mul_mat_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec);
+
+ // indirect matrix multiplication
+ GGML_API struct ggml_tensor * ggml_mul_mat_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * as,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids);
+
+ // A: m columns, n rows,
+ // B: p columns, n rows,
+ // result is m columns, p rows
+ GGML_API struct ggml_tensor * ggml_out_prod(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ //
+ // operations on tensors without backpropagation
+ //
+
+ GGML_API struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_scale_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return modified a
+ GGML_API struct ggml_tensor * ggml_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return view(a)
+ GGML_API struct ggml_tensor * ggml_set_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_set_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_set_1d_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t offset);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return modified a
+ GGML_API struct ggml_tensor * ggml_set_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t offset);
+
+    // b -> view(a,offset,nb1,nb2,nb3), return view(a)
+ GGML_API struct ggml_tensor * ggml_set_2d_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t offset);
+
+ // a -> b, return view(b)
+ GGML_API struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_type type);
+
+ // make contiguous
+ GGML_API struct ggml_tensor * ggml_cont(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // make contiguous, with new shape
+ GGML_API struct ggml_tensor * ggml_cont_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0);
+
+ GGML_API struct ggml_tensor * ggml_cont_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1);
+
+ GGML_API struct ggml_tensor * ggml_cont_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ GGML_API struct ggml_tensor * ggml_cont_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
+ // return view(a), b specifies the new shape
+ // TODO: when we start computing gradient, make a copy instead of view
+ GGML_API struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // return view(a)
+ // TODO: when we start computing gradient, make a copy instead of view
+ GGML_API struct ggml_tensor * ggml_reshape_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0);
+
+ GGML_API struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1);
+
+ // return view(a)
+ // TODO: when we start computing gradient, make a copy instead of view
+ GGML_API struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ GGML_API struct ggml_tensor * ggml_reshape_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
+ // offset in bytes
+ GGML_API struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ size_t nb1, // row stride in bytes
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_view_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ size_t nb1, // row stride in bytes
+ size_t nb2, // slice stride in bytes
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_view_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ size_t nb1, // row stride in bytes
+ size_t nb2, // slice stride in bytes
+ size_t nb3,
+ size_t offset);
+
+ GGML_API struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3);
+
+ // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+ GGML_API struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // supports 3D: a->ne[2] == b->ne[1]
+ GGML_API struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_get_rows_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c);
+
+ GGML_API struct ggml_tensor * ggml_diag(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // set elements above the diagonal to -INF
+ GGML_API struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+ // set elements above the diagonal to 0
+ GGML_API struct ggml_tensor * ggml_diag_mask_zero(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past);
+
+ GGML_API struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_soft_max_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ // fused soft_max(a*scale + mask*(ALiBi slope))
+ // mask is optional
+ // max_bias = 0.0f for no ALiBi
+ GGML_API struct ggml_tensor * ggml_soft_max_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias);
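+
+    // e.g. a plain scaled softmax without ALiBi; kq, kq_mask and n_embd_head are
+    // placeholder names, not symbols defined by this header:
+    //
+    //   ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head), 0.0f);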
+
+ GGML_API struct ggml_tensor * ggml_soft_max_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // rotary position embedding
+    // if mode & 1 != 0, skip n_past elements (NOT SUPPORTED)
+    // if mode & 2 != 0, GPT-NeoX style
+ //
+ // b is an int32 vector with size a->ne[2], it contains the positions
+ // c is freq factors (e.g. phi3-128k), (optional)
+ GGML_API struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode);
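+
+    // e.g. a minimal sketch; pos and n_rot are placeholder names:
+    //
+    //   struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[2]);
+    //   // fill pos with the position of each token: 0, 1, ..., a->ne[2] - 1
+    //   struct ggml_tensor * rotated = ggml_rope(ctx, a, pos, n_rot, 0);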
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_rope_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode);
+
+ // custom RoPE
+ GGML_API struct ggml_tensor * ggml_rope_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext_inplace instead");
+
+ // compute correction dims for YaRN RoPE scaling
+ GGML_CALL void ggml_rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+    // rotary position embedding backward, i.e. compute dx from dy
+ // a - dy
+ GGML_API struct ggml_tensor * ggml_rope_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);
+
+ // clamp
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_clamp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float min,
+ float max);
+
+ GGML_API struct ggml_tensor * ggml_im2col(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1,
+ bool is_2D,
+ enum ggml_type dst_type);
+
+ GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1);
+
+ GGML_API struct ggml_tensor * ggml_conv_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0, // stride
+ int p0, // padding
+ int d0); // dilation
+
+ // conv_1d with padding = half
+ // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor * ggml_conv_1d_ph(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s,
+ int d);
+
+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int p0,
+ int d0);
+
+ GGML_API struct ggml_tensor * ggml_conv_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1);
+
+
+ // kernel size is a->ne[0] x a->ne[1]
+ // stride is equal to kernel size
+ // padding is zero
+ // example:
+ // a: 16 16 3 768
+ // b: 1024 1024 3 1
+ // res: 64 64 768 1
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // kernel size is a->ne[0] x a->ne[1]
+ // stride is 1
+ // padding is half
+ // example:
+ // a: 3 3 256 256
+ // b: 64 64 256 1
+ // res: 64 64 256 1
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int stride);
+
+ enum ggml_op_pool {
+ GGML_OP_POOL_MAX,
+ GGML_OP_POOL_AVG,
+ GGML_OP_POOL_COUNT,
+ };
+
+ GGML_API struct ggml_tensor * ggml_pool_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_op_pool op,
+ int k0, // kernel size
+ int s0, // stride
+ int p0); // padding
+
+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
+ GGML_API struct ggml_tensor * ggml_pool_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_op_pool op,
+ int k0,
+ int k1,
+ int s0,
+ int s1,
+ float p0,
+ float p1);
+
+ // nearest interpolate
+ // multiplies ne0 and ne1 by scale factor
+ // used in stable-diffusion
+ GGML_API struct ggml_tensor * ggml_upscale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int scale_factor);
+
+    // nearest interpolate to the specified dimensions
+ // used in tortoise.cpp
+ GGML_API struct ggml_tensor * ggml_upscale_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3);
+
+ // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+ GGML_API struct ggml_tensor * ggml_pad(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int p0,
+ int p1,
+ int p2,
+ int p3);
+
+ // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+ // timesteps: [N,]
+ // return: [N, dim]
+ GGML_API struct ggml_tensor * ggml_timestep_embedding(
+ struct ggml_context * ctx,
+ struct ggml_tensor * timesteps,
+ int dim,
+ int max_period);
+
+ // sort rows
+ enum ggml_sort_order {
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
+ };
+
+ GGML_API struct ggml_tensor * ggml_argsort(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_sort_order order);
+
+ GGML_API struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step);
+
+ // top k elements per row
+ GGML_API struct ggml_tensor * ggml_top_k(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k);
+
+#define GGML_KQ_MASK_PAD 32
+
+ // q: [n_embd, n_batch, n_head, 1]
+ // k: [n_embd, n_kv, n_head_kv, 1]
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
+ GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias);
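+
+    // sketch of the mask padding described above; n_kv and n_batch are placeholder
+    // names and the F16 mask type is an assumption:
+    //
+    //   const int64_t n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD);
+    //   struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, n_batch_pad);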
+
+ GGML_API void ggml_flash_attn_ext_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec);
+
+ // TODO: needs to be adapted to ggml_flash_attn_ext
+ GGML_API struct ggml_tensor * ggml_flash_attn_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * d,
+ bool masked);
+
+ GGML_API struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq);
+
+ GGML_API struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq);
+
+ // partition into non-overlapping windows with padding if needed
+ // example:
+ // a: 768 64 64 1
+ // w: 14
+ // res: 768 14 14 25
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_win_part(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int w);
+
+ // reverse of ggml_win_part
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_win_unpart(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int w0,
+ int h0,
+ int w);
+
+ GGML_API struct ggml_tensor * ggml_unary(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_unary_op op);
+
+ GGML_API struct ggml_tensor * ggml_unary_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_unary_op op);
+
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_get_rel_pos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int qh,
+ int kh);
+
+ // used in sam
+ GGML_API struct ggml_tensor * ggml_add_rel_pos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * pw,
+ struct ggml_tensor * ph);
+
+ GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * pw,
+ struct ggml_tensor * ph);
+
+ // custom operators
+
+ typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
+ typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+ typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
+ typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+ typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_unary_op_f32_t fun),
+ "use ggml_map_custom1 instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_unary_op_f32_t fun),
+ "use ggml_map_custom1_inplace instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_binary_op_f32_t fun),
+ "use ggml_map_custom2 instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_binary_op_f32_t fun),
+ "use ggml_map_custom2_inplace instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_custom1_op_f32_t fun),
+ "use ggml_map_custom1 instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_custom1_op_f32_t fun),
+ "use ggml_map_custom1_inplace instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_custom2_op_f32_t fun),
+ "use ggml_map_custom2 instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_custom2_op_f32_t fun),
+ "use ggml_map_custom2_inplace instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ ggml_custom3_op_f32_t fun),
+ "use ggml_map_custom3 instead");
+
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ ggml_custom3_op_f32_t fun),
+ "use ggml_map_custom3_inplace instead");
+
+ // custom operators v2
+
+ typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
+ typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
+ typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+
+ #define GGML_N_TASKS_MAX -1
+
+ GGML_API struct ggml_tensor * ggml_map_custom1(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_custom1_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ ggml_custom1_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ GGML_API struct ggml_tensor * ggml_map_custom2(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_custom2_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ ggml_custom2_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ GGML_API struct ggml_tensor * ggml_map_custom3(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ ggml_custom3_op_t fun,
+ int n_tasks,
+ void * userdata);
+
+ GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ ggml_custom3_op_t fun,
+ int n_tasks,
+ void * userdata);
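+
+ // a minimal sketch of a custom operator (hedged example; my_relu is a hypothetical helper and
+ // assumes contiguous F32 tensors): each thread ith of nth processes a strided slice of dst
+ //
+ //   static void my_relu(struct ggml_tensor * dst, const struct ggml_tensor * a,
+ //                       int ith, int nth, void * userdata) {
+ //       const int64_t n = ggml_nelements(dst);
+ //       for (int64_t i = ith; i < n; i += nth) {
+ //           ((float *) dst->data)[i] = fmaxf(((const float *) a->data)[i], 0.0f);
+ //       }
+ //   }
+ //
+ //   struct ggml_tensor * r = ggml_map_custom1(ctx, a, my_relu, GGML_N_TASKS_MAX, NULL);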
+
+ // loss function
+
+ GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c);
+
+ //
+ // automatic differentiation
+ //
+
+ GGML_API void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor);
+
+
+ GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+ GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
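+
+ // a minimal autodiff sketch (hedged example; x and loss are hypothetical tensors, with loss a
+ // scalar that depends on x):
+ //
+ //   ggml_set_param(ctx, x); // mark x as a trainable parameter
+ //   struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+ //   ggml_build_forward_expand(gf, loss);
+ //   struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
+ //   ggml_build_backward_expand(ctx, gf, gb, /*keep =*/ false);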
+
+ // graph allocation in a context
+ GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+ GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
+ GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+ GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
+ GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
+ GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
+
+ GGML_API size_t ggml_graph_overhead(void);
+ GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
+
+ // ggml_graph_plan() has to be called before ggml_graph_compute()
+ // when plan.work_size > 0, caller must allocate memory for plan.work_data
+ GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+ GGML_API enum ggml_status ggml_graph_compute ( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+ // same as ggml_graph_compute() but the work data is allocated as a part of the context
+ // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+ GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
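+
+ // a minimal compute sketch (hedged example; assumes ctx was created with enough memory and
+ // that a and b are compatible F32 matrices):
+ //
+ //   struct ggml_tensor * c  = ggml_mul_mat(ctx, a, b);
+ //   struct ggml_cgraph * gf = ggml_new_graph(ctx);
+ //   ggml_build_forward_expand(gf, c);
+ //   ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);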
+
+ GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
+
+ GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+ GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+
+ // print info and performance information for the graph
+ GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+ // dump the graph into a file using the dot format
+ GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+
+ // build gradient checkpointing backward graph gb for gf using provided checkpoints
+ // gb_tmp will contain the original backward graph with rewritten backward-process nodes,
+ // but without the second forward pass nodes.
+ GGML_API void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints);
+
+ //
+ // optimization
+ //
+
+ // optimization methods
+ enum ggml_opt_type {
+ GGML_OPT_TYPE_ADAM,
+ GGML_OPT_TYPE_LBFGS,
+ };
+
+ // linesearch methods
+ enum ggml_linesearch {
+ GGML_LINESEARCH_DEFAULT = 1,
+
+ GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
+ GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
+ GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+ };
+
+ // optimization return values
+ enum ggml_opt_result {
+ GGML_OPT_RESULT_OK = 0,
+ GGML_OPT_RESULT_DID_NOT_CONVERGE,
+ GGML_OPT_RESULT_NO_CONTEXT,
+ GGML_OPT_RESULT_INVALID_WOLFE,
+ GGML_OPT_RESULT_FAIL,
+ GGML_OPT_RESULT_CANCEL,
+
+ GGML_LINESEARCH_FAIL = -128,
+ GGML_LINESEARCH_MINIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_STEP,
+ GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+ GGML_LINESEARCH_INVALID_PARAMETERS,
+ };
+
+ typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+ typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
+
+ // optimization parameters
+ //
+ // see ggml.c (ggml_opt_default_params) for default values
+ //
+ struct ggml_opt_params {
+ enum ggml_opt_type type;
+
+ size_t graph_size;
+
+ int n_threads;
+
+ // delta-based convergence test
+ //
+ // if past == 0 - disabled
+ // if past > 0:
+ // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+ //
+ int past;
+ float delta;
+
+ // maximum number of iterations without improvement
+ //
+ // if 0 - disabled
+ // if > 0:
+ // assume convergence if no cost improvement in this number of iterations
+ //
+ int max_no_improvement;
+
+ bool print_forward_graph;
+ bool print_backward_graph;
+
+ int n_gradient_accumulation;
+
+ // ADAM parameters
+ struct {
+ int n_iter;
+
+ float sched; // schedule multiplier (fixed, decay or warmup)
+ float decay; // weight decay for AdamW, use 0.0f to disable
+ int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
+ float alpha; // learning rate
+ float beta1;
+ float beta2;
+ float eps; // epsilon for numerical stability
+ float eps_f; // epsilon for convergence test
+ float eps_g; // epsilon for convergence test
+ float gclip; // gradient clipping
+ } adam;
+
+ // LBFGS parameters
+ struct {
+ int m; // number of corrections to approximate the inv. Hessian
+ int n_iter;
+ int max_linesearch;
+
+ float eps; // convergence tolerance
+ float ftol; // line search tolerance
+ float wolfe;
+ float min_step;
+ float max_step;
+
+ enum ggml_linesearch linesearch;
+ } lbfgs;
+ };
+
+ struct ggml_opt_context {
+ struct ggml_context * ctx;
+ struct ggml_opt_params params;
+
+ int iter;
+ int64_t nx; // number of parameter elements
+
+ bool just_initialized;
+
+ float loss_before;
+ float loss_after;
+
+ struct {
+ struct ggml_tensor * g; // current gradient
+ struct ggml_tensor * m; // first moment
+ struct ggml_tensor * v; // second moment
+ struct ggml_tensor * pf; // past function values
+ float fx_best;
+ float fx_prev;
+ int n_no_improvement;
+ } adam;
+
+ struct {
+ struct ggml_tensor * x; // current parameters
+ struct ggml_tensor * xp; // previous parameters
+ struct ggml_tensor * g; // current gradient
+ struct ggml_tensor * gp; // previous gradient
+ struct ggml_tensor * d; // search direction
+ struct ggml_tensor * pf; // past function values
+ struct ggml_tensor * lmal; // the L-BFGS memory alpha
+ struct ggml_tensor * lmys; // the L-BFGS memory ys
+ struct ggml_tensor * lms; // the L-BFGS memory s
+ struct ggml_tensor * lmy; // the L-BFGS memory y
+ float fx_best;
+ float step;
+ int j;
+ int k;
+ int end;
+ int n_no_improvement;
+ } lbfgs;
+ };
+
+ GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+ // optimize the function defined by the tensor f
+ GGML_API enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f);
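+
+ // a minimal sketch (hedged example; f is a hypothetical scalar loss whose inputs were marked
+ // with ggml_set_param):
+ //
+ //   struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
+ //   params.adam.n_iter = 100;
+ //   enum ggml_opt_result res = ggml_opt(ctx, params, f);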
+
+ // initialize optimizer context
+ GGML_API void ggml_opt_init(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_opt_params params,
+ int64_t nx);
+
+ // continue optimizing the function defined by the tensor f
+ GGML_API enum ggml_opt_result ggml_opt_resume(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f);
+
+ // continue optimizing the function defined by the tensor f
+ GGML_API enum ggml_opt_result ggml_opt_resume_g(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ ggml_opt_callback callback,
+ void * callback_data);
+
+ //
+ // tensor flags
+ //
+ GGML_API void ggml_set_input(struct ggml_tensor * tensor);
+ GGML_API void ggml_set_output(struct ggml_tensor * tensor);
+
+ //
+ // quantization
+ //
+
+ // - ggml_quantize_init can be called multiple times with the same type
+ // it will only initialize the quantization tables for the first call or after ggml_quantize_free
+ // automatically called by ggml_quantize_chunk for convenience
+ //
+ // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
+ // call this at the end of the program to avoid memory leaks
+ //
+ // note: these are thread-safe
+ //
+ GGML_API void ggml_quantize_init(enum ggml_type type);
+ GGML_API void ggml_quantize_free(void);
+
+ // some quantization types cannot be used without an importance matrix
+ GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
+
+ // calls ggml_quantize_init internally (i.e. can allocate memory)
+ GGML_API size_t ggml_quantize_chunk(
+ enum ggml_type type,
+ const float * src,
+ void * dst,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
+ const float * imatrix);
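+
+ // a minimal sketch (hedged example; src, dst, nrows and n_per_row are hypothetical buffers and
+ // sizes, and dst must be large enough to hold the quantized rows):
+ //
+ //   size_t n_bytes = ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst,
+ //           /*start =*/ 0, nrows, n_per_row, /*imatrix =*/ NULL);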
+
+ //
+ // gguf
+ //
+
+ enum gguf_type {
+ GGUF_TYPE_UINT8 = 0,
+ GGUF_TYPE_INT8 = 1,
+ GGUF_TYPE_UINT16 = 2,
+ GGUF_TYPE_INT16 = 3,
+ GGUF_TYPE_UINT32 = 4,
+ GGUF_TYPE_INT32 = 5,
+ GGUF_TYPE_FLOAT32 = 6,
+ GGUF_TYPE_BOOL = 7,
+ GGUF_TYPE_STRING = 8,
+ GGUF_TYPE_ARRAY = 9,
+ GGUF_TYPE_UINT64 = 10,
+ GGUF_TYPE_INT64 = 11,
+ GGUF_TYPE_FLOAT64 = 12,
+ GGUF_TYPE_COUNT, // marks the end of the enum
+ };
+
+ struct gguf_context;
+
+ struct gguf_init_params {
+ bool no_alloc;
+
+ // if not NULL, create a ggml_context and allocate the tensor data in it
+ struct ggml_context ** ctx;
+ };
+
+ GGML_API struct gguf_context * gguf_init_empty(void);
+ GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+ //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+ GGML_API void gguf_free(struct gguf_context * ctx);
+
+ GGML_API const char * gguf_type_name(enum gguf_type type);
+
+ GGML_API int gguf_get_version (const struct gguf_context * ctx);
+ GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
+ GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
+ GGML_API void * gguf_get_data (const struct gguf_context * ctx);
+
+ GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
+ GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
+ GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
+
+ GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
+ GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
+
+ // will abort if the wrong type is used for the key
+ GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id);
+ GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id);
+ GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
+ GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
+ GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
+ GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
+ GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
+ GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
+ GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
+ GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
+ GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
+ GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+ GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
+ GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
+ GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
+ GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
+
+ GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
+ GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
+ GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
+ GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
+ GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i);
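+
+ // a minimal read sketch (hedged example; the file name and key are hypothetical, and the key
+ // is assumed to hold a string value):
+ //
+ //   struct gguf_init_params params = { /*.no_alloc =*/ false, /*.ctx =*/ NULL };
+ //   struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
+ //   const int key_id = gguf_find_key(gctx, "general.name");
+ //   if (key_id >= 0) {
+ //       printf("%s\n", gguf_get_val_str(gctx, key_id));
+ //   }
+ //   gguf_free(gctx);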
+
+ // removes key if it exists
+ GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+ // overrides existing values or adds a new one
+ GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
+ GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
+ GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
+ GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
+ GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
+ GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
+ GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
+ GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
+ GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
+ GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
+ GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
+ GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+ GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
+ GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+
+ // set or add KV pairs from another context
+ GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+
+ // manage tensor info
+ GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+ GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+ GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+
+ // writing gguf files can be done in 2 ways:
+ //
+ // - write the entire gguf_context to a binary file in a single pass:
+ //
+ // gguf_write_to_file(ctx, fname);
+ //
+ // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+ //
+ // FILE * f = fopen(fname, "wb");
+ // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+ // fwrite(..., f); // write the tensor data
+ // void * data = malloc(gguf_get_meta_size(ctx));
+ // gguf_get_meta_data(ctx, data);
+ // fseek(f, 0, SEEK_SET);
+ // fwrite(data, 1, gguf_get_meta_size(ctx), f);
+ // free(data);
+ // fclose(f);
+ //
+
+ // write the entire context to a binary file
+ GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+ // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+ GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+ GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+ //
+ // system info
+ //
+
+ GGML_API int ggml_cpu_has_avx (void);
+ GGML_API int ggml_cpu_has_avx_vnni (void);
+ GGML_API int ggml_cpu_has_avx2 (void);
+ GGML_API int ggml_cpu_has_avx512 (void);
+ GGML_API int ggml_cpu_has_avx512_vbmi(void);
+ GGML_API int ggml_cpu_has_avx512_vnni(void);
+ GGML_API int ggml_cpu_has_avx512_bf16(void);
+ GGML_API int ggml_cpu_has_fma (void);
+ GGML_API int ggml_cpu_has_neon (void);
+ GGML_API int ggml_cpu_has_sve (void);
+ GGML_API int ggml_cpu_has_arm_fma (void);
+ GGML_API int ggml_cpu_has_metal (void);
+ GGML_API int ggml_cpu_has_f16c (void);
+ GGML_API int ggml_cpu_has_fp16_va (void);
+ GGML_API int ggml_cpu_has_wasm_simd (void);
+ GGML_API int ggml_cpu_has_blas (void);
+ GGML_API int ggml_cpu_has_cuda (void);
+ GGML_API int ggml_cpu_has_vulkan (void);
+ GGML_API int ggml_cpu_has_kompute (void);
+ GGML_API int ggml_cpu_has_gpublas (void);
+ GGML_API int ggml_cpu_has_sse3 (void);
+ GGML_API int ggml_cpu_has_ssse3 (void);
+ GGML_API int ggml_cpu_has_sycl (void);
+ GGML_API int ggml_cpu_has_rpc (void);
+ GGML_API int ggml_cpu_has_vsx (void);
+ GGML_API int ggml_cpu_has_matmul_int8(void);
+ GGML_API int ggml_cpu_has_cann (void);
+ GGML_API int ggml_cpu_has_llamafile (void);
+
+ //
+ // Internal types and functions exposed for tests and benchmarks
+ //
+
+#ifdef __cplusplus
+// restrict not standard in C++
+#define GGML_RESTRICT
+#else
+#define GGML_RESTRICT restrict
+#endif
+ typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+ typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+ typedef void (*ggml_from_float_to_mat_t)
+ (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+ typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+ const void * GGML_RESTRICT y, size_t by, int nrc);
+ typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+ const void * GGML_RESTRICT y, int nr, int nc);
+ typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+ const void * GGML_RESTRICT y, int nr, int nc);
+
+ typedef struct {
+ const char * type_name;
+ int64_t blck_size;
+ int64_t blck_size_interleave; // interleave elements in blocks
+ size_t type_size;
+ bool is_quantized;
+ ggml_to_float_t to_float;
+ ggml_from_float_t from_float;
+ ggml_from_float_t from_float_ref;
+ ggml_from_float_to_mat_t from_float_to_mat;
+ ggml_vec_dot_t vec_dot;
+ enum ggml_type vec_dot_type;
+ int64_t nrows; // number of rows to process simultaneously
+ int64_t ncols; // number of columns to process simultaneously
+ ggml_gemv_t gemv;
+ ggml_gemm_t gemm;
+ } ggml_type_traits_t;
+
+ GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
new file mode 100644
index 00000000..9888313d
--- /dev/null
+++ b/ggml/src/CMakeLists.txt
@@ -0,0 +1,1291 @@
+include(CheckCXXCompilerFlag)
+
+unset(GGML_CDEF_PUBLIC)
+
+add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})
+
+# enable libstdc++ assertions for debug builds
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
+endif()
+
+if (NOT MSVC)
+ if (GGML_SANITIZE_THREAD)
+ add_compile_options(-fsanitize=thread)
+ link_libraries (-fsanitize=thread)
+ endif()
+
+ if (GGML_SANITIZE_ADDRESS)
+ add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
+ link_libraries (-fsanitize=address)
+ endif()
+
+ if (GGML_SANITIZE_UNDEFINED)
+ add_compile_options(-fsanitize=undefined)
+ link_libraries (-fsanitize=undefined)
+ endif()
+endif()
+
+if (APPLE AND GGML_ACCELERATE)
+ find_library(ACCELERATE_FRAMEWORK Accelerate)
+ if (ACCELERATE_FRAMEWORK)
+ message(STATUS "Accelerate framework found")
+
+ add_compile_definitions(GGML_USE_ACCELERATE)
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
+ add_compile_definitions(ACCELERATE_LAPACK_ILP64)
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
+ else()
+ message(WARNING "Accelerate framework not found")
+ endif()
+endif()
+
+if (GGML_METAL)
+ find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+ find_library(METAL_FRAMEWORK Metal REQUIRED)
+ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+
+ message(STATUS "Metal framework found")
+ set(GGML_HEADERS_METAL ../include/ggml-metal.h)
+ set(GGML_SOURCES_METAL ggml-metal.m)
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_METAL)
+ if (GGML_METAL_NDEBUG)
+ add_compile_definitions(GGML_METAL_NDEBUG)
+ endif()
+
+ # copy ggml-common.h and ggml-metal.metal to bin directory
+ configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
+ configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+
+ if (GGML_METAL_EMBED_LIBRARY)
+ enable_language(ASM)
+
+ add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
+
+ set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h")
+ set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+
+ file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+
+ # merge ggml-common.h and ggml-metal.metal into a single file
+ set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+ set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+
+ add_custom_command(
+ OUTPUT ${METALLIB_EMBED_ASM}
+ COMMAND echo "Embedding Metal library"
+ COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
+ COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
+ COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
+ COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
+ COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
+ COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
+ COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
+ DEPENDS ggml-metal.metal ggml-common.h
+ COMMENT "Generate assembly for embedded Metal library"
+ )
+
+ set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${METALLIB_EMBED_ASM})
+ else()
+ if (GGML_METAL_SHADER_DEBUG)
+ # custom command to do the following:
+ # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
+ # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
+ #
+ # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
+ # disabling fast math is needed in order to pass tests/test-backend-ops
+ # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
+ # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
+ # ref: https://github.com/ggerganov/whisper.cpp/issues/1720
+ set(XC_FLAGS -fno-fast-math -fno-inline -g)
+ else()
+ set(XC_FLAGS -O3)
+ endif()
+
+ # Append macOS metal versioning flags
+ if (GGML_METAL_MACOSX_VERSION_MIN)
+ message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
+ list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
+ endif()
+
+ if (GGML_METAL_STD)
+ message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation")
+ list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
+ endif()
+
+ add_custom_command(
+ OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+ COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+ COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+ COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air
+ COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
+ COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
+ DEPENDS ggml-metal.metal ggml-common.h
+ COMMENT "Compiling Metal kernels"
+ )
+
+ add_custom_target(
+ ggml-metal ALL
+ DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
+ )
+ endif() # GGML_METAL_EMBED_LIBRARY
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS}
+ ${FOUNDATION_LIBRARY}
+ ${METAL_FRAMEWORK}
+ ${METALKIT_FRAMEWORK}
+ )
+endif()
+
+if (GGML_OPENMP)
+ find_package(OpenMP)
+ if (OpenMP_FOUND)
+ message(STATUS "OpenMP found")
+
+ add_compile_definitions(GGML_USE_OPENMP)
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
+ else()
+ message(WARNING "OpenMP not found")
+ endif()
+endif()
+
+if (GGML_BLAS)
+ if (GGML_STATIC)
+ set(BLA_STATIC ON)
+ endif()
+ #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
+ # set(BLA_SIZEOF_INTEGER 8)
+ #endif()
+
+ set(BLA_VENDOR ${GGML_BLAS_VENDOR})
+ find_package(BLAS)
+
+ if (BLAS_FOUND)
+ message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
+
+ if (("${BLAS_INCLUDE_DIRS}" STREQUAL "") AND NOT (${GGML_BLAS_VENDOR} MATCHES "Apple"))
+ # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
+ find_package(PkgConfig REQUIRED)
+ if (${GGML_BLAS_VENDOR} MATCHES "Generic")
+ pkg_check_modules(DepBLAS REQUIRED blas)
+ elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
+ # As of openblas v0.3.22, the 64-bit pkg-config file is named openblas64.pc
+ pkg_check_modules(DepBLAS openblas64)
+ if (NOT DepBLAS_FOUND)
+ pkg_check_modules(DepBLAS REQUIRED openblas)
+ endif()
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
+ pkg_check_modules(DepBLAS REQUIRED blis)
+ elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
+ pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
+ pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+ elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+ # all Intel* libraries share the same include path
+ pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+ elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
+ # NVHPC does not provide pkg-config files
+ # suggest assigning BLAS_INCLUDE_DIRS manually
+ if ("${NVHPC_VERSION}" STREQUAL "")
+ message(WARNING "Please set NVHPC_VERSION so that BLAS_INCLUDE_DIRS can be determined")
+ else()
+ set(DepBLAS_FOUND ON)
+ set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
+ endif()
+ endif()
+ if (DepBLAS_FOUND)
+ set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
+ else()
+ message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
+ " detected by pkgconfig, trying to find cblas.h from possible paths...")
+ find_path(BLAS_INCLUDE_DIRS
+ NAMES cblas.h
+ HINTS
+ /usr/include
+ /usr/local/include
+ /usr/include/openblas
+ /opt/homebrew/opt/openblas/include
+ /usr/local/opt/openblas/include
+ /usr/include/x86_64-linux-gnu/openblas/include
+ )
+ endif()
+ endif()
+
+ message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
+
+ add_compile_options(${BLAS_LINKER_FLAGS})
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_BLAS)
+
+ if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+ add_compile_definitions(GGML_BLAS_USE_MKL)
+ endif()
+
+ set(GGML_HEADERS_BLAS ../include/ggml-blas.h)
+ set(GGML_SOURCES_BLAS ggml-blas.cpp)
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${BLAS_LIBRARIES})
+ set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
+ else()
+ message(WARNING "BLAS not found, please refer to "
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+ " to set correct GGML_BLAS_VENDOR")
+ endif()
+endif()
+
+set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp)
+if (GGML_IQK_MUL_MAT)
+ message(STATUS "Using optimized iqk matrix multiplications")
+ add_compile_definitions(GGML_USE_IQK_MULMAT)
+ set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
+ set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+endif()
+
+if (GGML_LLAMAFILE)
+ message(STATUS "Using llamafile")
+
+ add_compile_definitions(GGML_USE_LLAMAFILE)
+
+ set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h)
+ set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
+endif()
+
+if (GGML_CUDA)
+ cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
+
+ find_package(CUDAToolkit)
+
+ if (CUDAToolkit_FOUND)
+ message(STATUS "CUDA found")
+
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+ # 52 == lowest CUDA 12 standard
+ # 60 == FP16 CUDA intrinsics
+ # 61 == integer CUDA intrinsics
+ # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+ set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+ else()
+ set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
+ #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
+ endif()
+ endif()
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
+ enable_language(CUDA)
+
+ file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
+ list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
+
+ file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
+ list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+ if (GGML_CUDA_FA_ALL_QUANTS)
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+ else()
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ endif()
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
+
+ add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+ add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
+
+ if (GGML_CUDA_USE_GRAPHS)
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+ endif()
+
+ if (GGML_CUDA_FORCE_DMMV)
+ add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+ endif()
+
+ if (GGML_CUDA_FORCE_MMQ)
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+ endif()
+
+ if (GGML_CUDA_FORCE_CUBLAS)
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+ endif()
+
+ if (GGML_CUDA_NO_VMM)
+ add_compile_definitions(GGML_CUDA_NO_VMM)
+ endif()
+
+ if (DEFINED GGML_CUDA_DMMV_Y)
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
+ endif()
+
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+ add_compile_definitions(GGML_CUDA_F16)
+ endif()
+
+ if (GGML_CUDA_NO_PEER_COPY)
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+ endif()
+
+ if (GGML_STATIC)
+ if (WIN32)
+ # As of CUDA Toolkit 12.3.1, the Windows toolkit does not offer a static cuBLAS library
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+ else ()
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ endif()
+ else()
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+ endif()
+
+ if (GGML_CUDA_NO_VMM)
+ # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+ else()
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
+ endif()
+ else()
+ message(WARNING "CUDA not found")
+ endif()
+endif()
+
+if (GGML_HIPBLAS)
+ if (NOT EXISTS $ENV{ROCM_PATH})
+ if (NOT EXISTS /opt/rocm)
+ set(ROCM_PATH /usr)
+ else()
+ set(ROCM_PATH /opt/rocm)
+ endif()
+ else()
+ set(ROCM_PATH $ENV{ROCM_PATH})
+ endif()
+
+ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+ list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
+
+ # CMake on Windows doesn't support the HIP language yet
+ if (WIN32)
+ set(CXX_IS_HIPCC TRUE)
+ else()
+ string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}")
+ endif()
+
+ if (CXX_IS_HIPCC)
+ if (LINUX)
+ if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+ message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+ endif()
+
+ message(WARNING "Setting hipcc as the C++ compiler is legacy behavior."
+ " Prefer setting the HIP compiler directly. See README for details.")
+ endif()
+ else()
+ # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
+ if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+ set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
+ endif()
+ cmake_minimum_required(VERSION 3.21)
+ enable_language(HIP)
+ endif()
+
+ find_package(hip REQUIRED)
+ find_package(hipblas REQUIRED)
+ find_package(rocblas REQUIRED)
+
+ message(STATUS "HIP and hipBLAS found")
+
+ file(GLOB GGML_HEADERS_ROCM "ggml-cuda/*.cuh")
+ list(APPEND GGML_HEADERS_ROCM "../include/ggml-cuda.h")
+
+ file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
+ list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+
+ if (GGML_CUDA_FA_ALL_QUANTS)
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+ else()
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+ file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
+ endif()
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA)
+
+ add_compile_definitions(GGML_USE_HIPBLAS)
+ add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+ add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+
+ if (GGML_HIP_UMA)
+ add_compile_definitions(GGML_HIP_UMA)
+ endif()
+
+ if (GGML_CUDA_FORCE_DMMV)
+ add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+ endif()
+
+ if (GGML_CUDA_FORCE_MMQ)
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+ endif()
+
+ if (GGML_CUDA_FORCE_CUBLAS)
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+ endif()
+
+ if (GGML_CUDA_NO_PEER_COPY)
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+ endif()
+
+ if (CXX_IS_HIPCC)
+ set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} hip::device)
+ else()
+ set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
+ endif()
+
+ if (GGML_STATIC)
+ message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+ endif()
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} PUBLIC hip::host roc::rocblas roc::hipblas)
+endif()
+
+if (GGML_SYCL)
+ if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$")
+ message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA")
+ endif()
+
+ check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
+ if (DEFINED ENV{ONEAPI_ROOT})
+ message(STATUS "Using oneAPI Release SYCL compiler (icpx).")
+ elseif(SUPPORTS_SYCL)
+ message(WARNING "Using open-source SYCL compiler (clang++). ENV{ONEAPI_ROOT} was not detected.
+ If you expected the oneAPI Release compiler, please install oneAPI and source it, e.g.:
+ source /opt/intel/oneapi/setvars.sh")
+ else()
+ message(FATAL_ERROR "C++ compiler lacks SYCL support.")
+ endif()
+ message(STATUS "SYCL found")
+ #todo: AOT
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_SYCL)
+
+ if (GGML_SYCL_F16)
+ add_compile_definitions(GGML_SYCL_F16)
+ endif()
+
+ if (GGML_CUDA_FORCE_MMQ)
+ add_compile_definitions(GGML_SYCL_FORCE_MMQ)
+ endif()
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
+
+ if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+ add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ else()
+ add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
+ endif()
+
+ file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
+ list(APPEND GGML_HEADERS_SYCL "../include/ggml-sycl.h")
+
+ file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
+ list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
+
+ if (WIN32)
+ find_package(IntelSYCL REQUIRED)
+ find_package(MKL REQUIRED)
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+ else()
+ if (GGML_SYCL_TARGET STREQUAL "INTEL")
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+ elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
+ endif()
+ endif()
+endif()
+
+if (GGML_RPC)
+ message(STATUS "RPC found")
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_RPC)
+
+ if (WIN32)
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ws2_32)
+ endif()
+
+ set(GGML_HEADERS_RPC ../include/ggml-rpc.h)
+ set(GGML_SOURCES_RPC ggml-rpc.cpp)
+endif()
+
+if (GGML_VULKAN)
+ find_package(Vulkan COMPONENTS glslc REQUIRED)
+
+ if (Vulkan_FOUND)
+ message(STATUS "Vulkan found")
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN)
+
+ # Workaround for the "can't dereference invalidated vector iterator" bug in clang-cl debug builds
+ # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
+ if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+ add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
+ endif()
+
+ if (GGML_VULKAN_CHECK_RESULTS)
+ add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
+ endif()
+
+ if (GGML_VULKAN_DEBUG)
+ add_compile_definitions(GGML_VULKAN_DEBUG)
+ endif()
+
+ if (GGML_VULKAN_MEMORY_DEBUG)
+ add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
+ endif()
+
+ if (GGML_VULKAN_VALIDATE)
+ add_compile_definitions(GGML_VULKAN_VALIDATE)
+ endif()
+
+ if (GGML_VULKAN_RUN_TESTS)
+ add_compile_definitions(GGML_VULKAN_RUN_TESTS)
+ endif()
+
+ add_subdirectory(vulkan-shaders)
+
+ set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
+ set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
+ set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
+ set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
+ set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
+
+ file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
+
+ add_custom_command(
+ OUTPUT ${_ggml_vk_header}
+ ${_ggml_vk_source}
+
+ COMMAND ${_ggml_vk_genshaders_cmd}
+ --glslc ${Vulkan_GLSLC_EXECUTABLE}
+ --input-dir ${_ggml_vk_input_dir}
+ --output-dir ${_ggml_vk_output_dir}
+ --target-hpp ${_ggml_vk_header}
+ --target-cpp ${_ggml_vk_source}
+ --no-clean
+
+ DEPENDS ${_ggml_vk_shader_deps}
+ COMMENT "Generate vulkan shaders"
+ )
+
+ set(GGML_HEADERS_VULKAN ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml-vulkan.h ${_ggml_vk_header})
+ set(GGML_SOURCES_VULKAN ggml-vulkan.cpp ${_ggml_vk_source})
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan)
+ set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR})
+ else()
+ message(WARNING "Vulkan not found")
+ endif()
+endif()
+
+if (GGML_KOMPUTE)
+ add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
+
+ find_package(Vulkan COMPONENTS glslc REQUIRED)
+ find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
+
+ if (NOT glslc_executable)
+ message(FATAL_ERROR "glslc not found")
+ endif()
+
+ function(compile_shader)
+ set(options)
+ set(oneValueArgs)
+ set(multiValueArgs SOURCES)
+ cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+ foreach(source ${compile_shader_SOURCES})
+ get_filename_component(filename ${source} NAME)
+ set(spv_file ${filename}.spv)
+ add_custom_command(
+ OUTPUT ${spv_file}
+ DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
+ COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+ COMMENT "Compiling ${source} to ${spv_file}"
+ )
+
+ get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
+ set(FILE_NAME "shader${RAW_FILE_NAME}")
+ string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
+ string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
+ string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
+ set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
+ message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
+ if(CMAKE_GENERATOR MATCHES "Visual Studio")
+ add_custom_command(
+ OUTPUT ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ DEPENDS ${spv_file} xxd
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
+ )
+ else()
+ add_custom_command(
+ OUTPUT ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+ DEPENDS ${spv_file} xxd
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
+ )
+ endif()
+ endforeach()
+ endfunction()
+
+ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
+ message(STATUS "Kompute found")
+ set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
+ add_subdirectory(kompute)
+
+ # Compile our shaders
+ compile_shader(SOURCES
+ kompute-shaders/op_scale.comp
+ kompute-shaders/op_scale_8.comp
+ kompute-shaders/op_add.comp
+ kompute-shaders/op_addrow.comp
+ kompute-shaders/op_mul.comp
+ kompute-shaders/op_silu.comp
+ kompute-shaders/op_relu.comp
+ kompute-shaders/op_gelu.comp
+ kompute-shaders/op_softmax.comp
+ kompute-shaders/op_norm.comp
+ kompute-shaders/op_rmsnorm.comp
+ kompute-shaders/op_diagmask.comp
+ kompute-shaders/op_mul_mat_mat_f32.comp
+ kompute-shaders/op_mul_mat_f16.comp
+ kompute-shaders/op_mul_mat_q8_0.comp
+ kompute-shaders/op_mul_mat_q4_0.comp
+ kompute-shaders/op_mul_mat_q4_1.comp
+ kompute-shaders/op_mul_mat_q6_k.comp
+ kompute-shaders/op_getrows_f32.comp
+ kompute-shaders/op_getrows_f16.comp
+ kompute-shaders/op_getrows_q4_0.comp
+ kompute-shaders/op_getrows_q4_1.comp
+ kompute-shaders/op_getrows_q6_k.comp
+ kompute-shaders/op_rope_f16.comp
+ kompute-shaders/op_rope_f32.comp
+ kompute-shaders/op_cpy_f16_f16.comp
+ kompute-shaders/op_cpy_f16_f32.comp
+ kompute-shaders/op_cpy_f32_f16.comp
+ kompute-shaders/op_cpy_f32_f32.comp
+ )
+
+ # Create a custom target for our generated shaders
+ add_custom_target(generated_shaders DEPENDS
+ shaderop_scale.h
+ shaderop_scale_8.h
+ shaderop_add.h
+ shaderop_addrow.h
+ shaderop_mul.h
+ shaderop_silu.h
+ shaderop_relu.h
+ shaderop_gelu.h
+ shaderop_softmax.h
+ shaderop_norm.h
+ shaderop_rmsnorm.h
+ shaderop_diagmask.h
+ shaderop_mul_mat_mat_f32.h
+ shaderop_mul_mat_f16.h
+ shaderop_mul_mat_q8_0.h
+ shaderop_mul_mat_q4_0.h
+ shaderop_mul_mat_q4_1.h
+ shaderop_mul_mat_q6_k.h
+ shaderop_getrows_f32.h
+ shaderop_getrows_f16.h
+ shaderop_getrows_q4_0.h
+ shaderop_getrows_q4_1.h
+ shaderop_getrows_q6_k.h
+ shaderop_rope_f16.h
+ shaderop_rope_f32.h
+ shaderop_cpy_f16_f16.h
+ shaderop_cpy_f16_f32.h
+ shaderop_cpy_f32_f16.h
+ shaderop_cpy_f32_f32.h
+ )
+
+ # Create a custom command that depends on the generated_shaders
+ add_custom_command(
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+ COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+ DEPENDS generated_shaders
+ COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
+ )
+
+ # Add the stamp to the main sources to ensure dependency tracking
+ set(GGML_SOURCES_KOMPUTE ggml-kompute.cpp ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+ set(GGML_HEADERS_KOMPUTE ../include/ggml-kompute.h ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_KOMPUTE)
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} kompute)
+ set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR})
+ else()
+ message(WARNING "Kompute not found")
+ endif()
+endif()
+
+if (GGML_CPU_HBM)
+ find_library(memkind memkind REQUIRED)
+
+ message(STATUS "Using memkind for CPU HBM")
+
+ add_compile_definitions(GGML_USE_CPU_HBM)
+
+ target_link_libraries(ggml PUBLIC memkind)
+endif()
+
+if (GGML_CANN)
+ if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})
+ set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
+ message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}")
+ endif()
+
+ if (CANN_INSTALL_DIR)
+ # Only supported on UNIX-like systems
+ if (GGML_CANN)
+ if (NOT UNIX)
+ set(GGML_CANN OFF)
+ message(WARNING "CANN: the CANN toolkit only supports UNIX-like systems, not ${CMAKE_SYSTEM_NAME}. Turning off GGML_CANN")
+ endif()
+ endif()
+
+ # Supported platforms: x86-64, arm64
+ if (GGML_CANN)
+ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
+ elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64")
+ else()
+ set(GGML_CANN OFF)
+ message(WARNING "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}. Turning off GGML_CANN")
+ endif()
+ endif()
+
+ # Set header and libs
+ if(GGML_CANN)
+ set(CANN_INCLUDE_DIRS
+ ${CANN_INSTALL_DIR}/include
+ ${CANN_INSTALL_DIR}/include/aclnn
+ ${CANN_INSTALL_DIR}/acllib/include
+ )
+
+ # TODO: find libs
+ link_directories(
+ ${CANN_INSTALL_DIR}/lib64
+ )
+
+ add_subdirectory(ggml-cann/kernels)
+ list(APPEND CANN_LIBRARIES
+ ascendcl
+ nnopbase
+ opapi
+ acl_op_compiler
+ ascendc_kernels
+ )
+
+ set(GGML_HEADERS_CANN "../include/ggml-cann.h")
+ file(GLOB GGML_SOURCES_CANN "ggml-cann/*.cpp")
+ list(APPEND GGML_SOURCES_CANN "ggml-cann.cpp")
+
+ message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
+ message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
+
+ set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${CANN_LIBRARIES} )
+ set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CANN_INCLUDE_DIRS})
+ list(APPEND GGML_CDEF_PUBLIC GGML_USE_CANN)
+ endif()
+ else()
+ set(GGML_CANN OFF)
+ message(WARNING "CANN: Can't find CANN_INSTALL_DIR, do you forget to source set_var.sh. Turning off GGML_CANN")
+ endif()
+
+ if(NOT GGML_CANN)
+ message(WARNING "CANN: GGML_CANN is turned OFF, see above for details.")
+ endif()
+endif()
+
+function(get_flags CCID CCVER)
+ set(C_FLAGS "")
+ set(CXX_FLAGS "")
+
+ if (CCID MATCHES "Clang")
+ set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
+ set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
+
+ if (
+ (CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
+ (CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
+ )
+ list(APPEND C_FLAGS -Wdouble-promotion)
+ endif()
+ elseif (CCID STREQUAL "GNU")
+ set(C_FLAGS -Wdouble-promotion)
+ set(CXX_FLAGS -Wno-array-bounds)
+
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
+ list(APPEND CXX_FLAGS -Wno-format-truncation)
+ endif()
+ if (CCVER VERSION_GREATER_EQUAL 8.1.0)
+ list(APPEND CXX_FLAGS -Wextra-semi)
+ endif()
+ endif()
+
+ set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
+ set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
+endfunction()
+
+if (GGML_FATAL_WARNINGS)
+ if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ list(APPEND C_FLAGS -Werror)
+ list(APPEND CXX_FLAGS -Werror)
+ elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+ add_compile_options(/WX)
+ endif()
+endif()
+
+if (GGML_ALL_WARNINGS)
+ if (NOT MSVC)
+ list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
+ list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
+ -Werror=implicit-int -Werror=implicit-function-declaration)
+ list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
+
+ list(APPEND C_FLAGS ${WARNING_FLAGS})
+ list(APPEND CXX_FLAGS ${WARNING_FLAGS})
+
+ get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
+
+ add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
+ "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
+ else()
+ # todo : msvc
+ set(C_FLAGS "")
+ set(CXX_FLAGS "")
+ endif()
+endif()
+
+set(CUDA_CXX_FLAGS "")
+
+if (GGML_CUDA)
+ set(CUDA_FLAGS -use_fast_math)
+
+ if (GGML_FATAL_WARNINGS)
+ list(APPEND CUDA_FLAGS -Werror all-warnings)
+ endif()
+
+ if (GGML_ALL_WARNINGS AND NOT MSVC)
+ set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+ if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+ list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+ endif()
+
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler --version
+ OUTPUT_VARIABLE CUDA_CCFULLVER
+ ERROR_QUIET
+ )
+
+ if (NOT CUDA_CCFULLVER MATCHES clang)
+ set(CUDA_CCID "GNU")
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+ OUTPUT_VARIABLE CUDA_CCVER
+ ERROR_QUIET
+ )
+ else()
+ if (CUDA_CCFULLVER MATCHES Apple)
+ set(CUDA_CCID "AppleClang")
+ else()
+ set(CUDA_CCID "Clang")
+ endif()
+ string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+ endif()
+
+ message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+ get_flags(${CUDA_CCID} ${CUDA_CCVER})
+ list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
+ endif()
+
+ if (NOT MSVC)
+ list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+ endif()
+endif()
+
+if (GGML_LTO)
+ include(CheckIPOSupported)
+ check_ipo_supported(RESULT result OUTPUT output)
+ if (result)
+ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
+ else()
+ message(WARNING "IPO is not supported: ${output}")
+ endif()
+endif()
+
+if (GGML_CCACHE)
+ find_program(GGML_CCACHE_FOUND ccache)
+
+ if (GGML_CCACHE_FOUND)
+ # TODO: should not be set globally
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+ set(ENV{CCACHE_SLOPPINESS} time_macros)
+ message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
+ else()
+ message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
+ endif ()
+endif()
+
+# this version of Apple ld64 is buggy
+execute_process(
+ COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
+ ERROR_VARIABLE output
+ OUTPUT_QUIET
+)
+
+if (output MATCHES "dyld-1015\.7")
+ add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
+endif()
+
+# architecture specific
+# TODO: probably these flags need to be tweaked on some architectures
+# feel free to update this file for your architecture and send a pull request or issue
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+if (MSVC)
+ string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
+ message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
+else ()
+ set(CMAKE_GENERATOR_PLATFORM_LWR "")
+endif ()
+
+if (NOT MSVC)
+ if (GGML_STATIC)
+ add_link_options(-static)
+ if (MINGW)
+ add_link_options(-static-libgcc -static-libstdc++)
+ endif()
+ endif()
+ if (GGML_GPROF)
+ add_compile_options(-pg)
+ endif()
+endif()
+
+set(ARCH_FLAGS "")
+
+if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
+ CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
+ (NOT CMAKE_OSX_ARCHITECTURES AND
+ NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+ CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
+
+ message(STATUS "ARM detected")
+
+ if (MSVC)
+ add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
+ add_compile_definitions(__ARM_NEON)
+ add_compile_definitions(__ARM_FEATURE_FMA)
+
+ set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
+ string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")
+
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
+ if (GGML_COMPILER_SUPPORT_DOTPROD)
+ add_compile_definitions(__ARM_FEATURE_DOTPROD)
+ endif ()
+
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
+
+ if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
+ add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
+ endif ()
+
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
+ if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
+ add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ endif ()
+
+ set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
+ else()
+ check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
+ if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
+ list(APPEND ARCH_FLAGS -mfp16-format=ieee)
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
+ # Raspberry Pi 1, Zero
+ list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
+ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
+ # Android armeabi-v7a
+ list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
+ else()
+ # Raspberry Pi 2
+ list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
+ endif()
+ endif()
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
+ # Android arm64-v8a
+ # Raspberry Pi 3, 4, Zero 2 (32-bit)
+ list(APPEND ARCH_FLAGS -mno-unaligned-access)
+ endif()
+ if (GGML_SVE)
+ list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
+ endif()
+ endif()
+elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
+ (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+ CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
+ message(STATUS "x86 detected")
+ if (MSVC)
+ # instruction set detection for MSVC only
+ if (GGML_NATIVE)
+ # TODO: improve, should not reference files from the parent folder
+ include(../cmake/FindSIMD.cmake)
+ endif ()
+ if (GGML_AVX512)
+ list(APPEND ARCH_FLAGS /arch:AVX512)
+        # MSVC has no compile-time flags for enabling specific
+        # AVX512 extensions, nor does it define the
+        # macros corresponding to those extensions.
+        # Do it manually.
+ if (GGML_AVX512_VBMI)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
+ endif()
+ if (GGML_AVX512_VNNI)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+ endif()
+ if (GGML_AVX512_BF16)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+ endif()
+ elseif (GGML_AVX2)
+ list(APPEND ARCH_FLAGS /arch:AVX2)
+ elseif (GGML_AVX)
+ list(APPEND ARCH_FLAGS /arch:AVX)
+ endif()
+ else()
+ if (GGML_NATIVE)
+ list(APPEND ARCH_FLAGS -march=native)
+ endif()
+ if (GGML_F16C)
+ list(APPEND ARCH_FLAGS -mf16c)
+ endif()
+ if (GGML_FMA)
+ list(APPEND ARCH_FLAGS -mfma)
+ endif()
+ if (GGML_AVX)
+ list(APPEND ARCH_FLAGS -mavx)
+ endif()
+ if (GGML_AVX2)
+ list(APPEND ARCH_FLAGS -mavx2)
+ endif()
+ if (GGML_AVX512)
+ list(APPEND ARCH_FLAGS -mavx512f)
+ list(APPEND ARCH_FLAGS -mavx512bw)
+ endif()
+ if (GGML_AVX512_VBMI)
+ list(APPEND ARCH_FLAGS -mavx512vbmi)
+ endif()
+ if (GGML_AVX512_VNNI)
+ list(APPEND ARCH_FLAGS -mavx512vnni)
+ endif()
+ if (GGML_AVX512_BF16)
+ list(APPEND ARCH_FLAGS -mavx512bf16)
+ endif()
+ endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
+ message(STATUS "PowerPC detected")
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
+ list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
+ else()
+ list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
+        # TODO: add targets for Power8/Power9 (Altivec/VSX) and Power10 (MMA) and query for big-endian systems (ppc64/le/be)
+ endif()
+elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
+ message(STATUS "loongarch64 detected")
+
+ list(APPEND ARCH_FLAGS -march=loongarch64)
+ if (GGML_LASX)
+ list(APPEND ARCH_FLAGS -mlasx)
+ endif()
+ if (GGML_LSX)
+ list(APPEND ARCH_FLAGS -mlsx)
+ endif()
+else()
+ message(STATUS "Unknown architecture")
+endif()
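+
+# Illustrative note: ARCH_FLAGS collects the ISA flags selected above. For example, a non-MSVC
+# x86 configuration with GGML_F16C, GGML_FMA and GGML_AVX2 enabled (hypothetical settings) would
+# end up with ARCH_FLAGS = "-mf16c;-mfma;-mavx2", which is applied per-language just below.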
+
+add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
+add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
+
+if (GGML_CUDA)
+ list(APPEND CUDA_CXX_FLAGS ${ARCH_FLAGS})
+ list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
+
+ if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
+ list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
+ endif()
+
+ add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+endif()
+
+if (MINGW)
+ # Target Windows 8 for PrefetchVirtualMemory
+ add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
+endif()
+
+#
+# POSIX conformance
+#
+
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+add_compile_definitions(_XOPEN_SOURCE=600)
+
+# Somehow in OpenBSD whenever POSIX conformance is specified
+# some string functions rely on locale_t availability,
+# which was introduced in POSIX.1-2008, forcing us to go higher
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+ remove_definitions(-D_XOPEN_SOURCE=600)
+ add_compile_definitions(_XOPEN_SOURCE=700)
+endif()
+
+# Data types, macros and functions related to controlling CPU affinity and
+# some memory allocation are available on Linux through GNU extensions in libc
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+ add_compile_definitions(_GNU_SOURCE)
+endif()
+
+# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
+# and on macOS its availability depends on enabling Darwin extensions
+# similarly on DragonFly, enabling BSD extensions is necessary
+if (
+ CMAKE_SYSTEM_NAME MATCHES "Darwin" OR
+ CMAKE_SYSTEM_NAME MATCHES "iOS" OR
+ CMAKE_SYSTEM_NAME MATCHES "tvOS" OR
+ CMAKE_SYSTEM_NAME MATCHES "DragonFly"
+)
+ add_compile_definitions(_DARWIN_C_SOURCE)
+endif()
+
+# alloca is a non-standard interface that is not visible on BSDs when
+# POSIX conformance is specified, but not all of them provide a clean way
+# to enable it in such cases
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
+ add_compile_definitions(__BSD_VISIBLE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
+ add_compile_definitions(_NETBSD_SOURCE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+ add_compile_definitions(_BSD_SOURCE)
+endif()
+
+if (WIN32)
+ add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
+
+ if (BUILD_SHARED_LIBS)
+ # TODO: should not use this
+ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+ endif()
+endif()
+
+#
+# libraries
+#
+
+# ggml
+
+add_library(ggml
+ ../include/ggml.h
+ ../include/ggml-alloc.h
+ ../include/ggml-backend.h
+ ggml.c
+ ggml-alloc.c
+ ggml-backend.c
+ ggml-quants.c
+ ggml-quants.h
+ ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
+ ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
+ ${GGML_SOURCES_RPC} ${GGML_HEADERS_RPC}
+ ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
+ ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL}
+ ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
+ ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
+ ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
+ ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS}
+ ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
+ ${GGML_SOURCES_IQK_MM} ${GGML_HEADERS_IQK_MM}
+ ${GGML_SOURCES_IQK}
+ ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
+ ggml-aarch64.c ggml-aarch64.h
+ )
+
+if (EMSCRIPTEN)
+ set_target_properties(ggml PROPERTIES COMPILE_FLAGS "-msimd128")
+endif()
+
+target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
+target_include_directories(ggml PUBLIC ../include)
+target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
+target_compile_features (ggml PRIVATE c_std_11) # don't bump
+
+target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
+
+find_library(MATH_LIBRARY m)
+if (MATH_LIBRARY)
+ if (NOT WIN32 OR NOT GGML_SYCL)
+ target_link_libraries(ggml PRIVATE ${MATH_LIBRARY})
+ endif()
+endif()
+
+if (BUILD_SHARED_LIBS)
+ set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_compile_definitions(ggml PRIVATE GGML_SHARED GGML_BUILD)
+endif()
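+
+# Illustrative usage (hypothetical consumer, not part of this file): a parent project that adds
+# this directory can simply link against the target, e.g.
+#   target_link_libraries(my_app PRIVATE ggml)
+# where "my_app" is a placeholder target name.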
diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
new file mode 100644
index 00000000..af53dea1
--- /dev/null
+++ b/ggml/src/ggml-aarch64.c
@@ -0,0 +1,2193 @@
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h> // for GGML_ASSERT
+
+#include "ggml-aarch64.h"
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// Functions to create the interleaved data layout formats
+
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x4
+// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
+//
+// - in                   : a pointer to an array of block_q4_0 blocks
+// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
+//                          blck_size_interleave bytes
+// - xor_mask             : the mask to convert the nibbles in block_q4_0 quants bytes
+//                          from bias offset form to pure sign form (this saves subtract
+//                          operations during unpacking)
+//
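+// Worked example (illustrative), with blck_size_interleave == 4:
+//   out.qs[ 0.. 3] = in[0].qs[0..3], out.qs[ 4.. 7] = in[1].qs[0..3],
+//   out.qs[ 8..11] = in[2].qs[0..3], out.qs[12..15] = in[3].qs[0..3],
+//   out.qs[16..19] = in[0].qs[4..7], and so on,
+// i.e. blck_size_interleave bytes are taken from each of the 4 blocks in turn,
+// with each byte XORed with xor_mask.
+//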
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+ block_q4_0x4 out;
+
+ for (int i = 0; i < 4; i++) {
+ out.d[i] = in[i].d;
+ }
+
+ for (int i = 0; i < QK4_0 * 2; i++) {
+ int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+ int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+ src_offset += (i % blck_size_interleave);
+
+ out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+ }
+
+ return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
+ block_q4_0x8 out;
+
+ for (int i = 0; i < 8; i++) {
+ out.d[i] = in[i].d;
+ }
+
+ for (int i = 0; i < QK4_0 * 4; i++) {
+ int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+ int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+ src_offset += (i % blck_size_interleave);
+
+ out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
+ }
+
+ return out;
+}
+
+void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t srcv[4][8];
+ float id[4];
+
+ for (int i = 0; i < nb; i++) {
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
+ for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+ for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+ for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ id[row_iter] = d ? 1.0f / d : 0.0f;
+
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < 8; j++) {
+ float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
+ int32x4_t vi = vcvtnq_s32_f32(v);
+ y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[1][j], id[1]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
+ y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
+ y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
+ y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[2][j], id[2]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
+ y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
+ y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
+ y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[3][j], id[3]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
+ y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
+ y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
+ y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
+ }
+ }
+#else
+ // scalar
+ const int blck_size_interleave = 4;
+ float srcv[4][QK8_0];
+ float id[4];
+
+ for (int i = 0; i < nb; i++) {
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ id[row_iter] = d ? 1.0f / d : 0.0f;
+
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < QK8_0 * 4; j++) {
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+ src_offset += (j % blck_size_interleave);
+
+ float x0 = srcv[src_id][src_offset] * id[src_id];
+ y[i].qs[j] = roundf(x0);
+ }
+ }
+#endif
+}
+
+void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t srcv[4][8];
+ float id[4];
+
+ for (int i = 0; i < nb; i++) {
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
+ for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+ for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+ for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ id[row_iter] = d ? 1.0f / d : 0.0f;
+
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < 4; j++) {
+ float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
+ int32x4_t vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
+ v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[1][2 * j], id[1]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
+ v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[2][2 * j], id[2]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
+ v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
+
+ v = vmulq_n_f32(srcv[3][2 * j], id[3]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
+ v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
+ vi = vcvtnq_s32_f32(v);
+ y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
+ y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
+ y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
+ y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
+ }
+ }
+#else
+ // scalar
+ const int blck_size_interleave = 8;
+ float srcv[4][QK8_0];
+ float id[4];
+
+ for (int i = 0; i < nb; i++) {
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+ amax = MAX(amax, fabsf(srcv[row_iter][j]));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ id[row_iter] = d ? 1.0f / d : 0.0f;
+
+ y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < QK8_0 * 4; j++) {
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+ src_offset += (j % blck_size_interleave);
+
+ float x0 = srcv[src_id][src_offset] * id[src_id];
+ y[i].qs[j] = roundf(x0);
+ }
+ }
+#endif
+}
+
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+ assert(nrow == 4);
+ UNUSED(nrow);
+ if (blck_size_interleave == 4) {
+ quantize_q8_0_4x4(x, vy, n_per_row);
+ } else if (blck_size_interleave == 8) {
+ quantize_q8_0_4x8(x, vy, n_per_row);
+ } else {
+ assert(false);
+ }
+}
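+
+// Usage sketch (illustrative): for the 8-byte interleaved path one would call, for example,
+//   quantize_mat_q8_0(x, y, /*nrow=*/4, /*n_per_row=*/256, /*blck_size_interleave=*/8);
+// which quantizes 4 rows of 256 floats into 256/QK8_0 = 8 block_q8_0x4 blocks written to y.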
+
+static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
+ assert(n_per_row % QK4_0 == 0);
+ const int nb = n_per_row / QK4_0;
+
+ void * out_ptr = NULL;
+ if (nrows_interleaved == 8) {
+ out_ptr = (block_q4_0x8 *) dst;
+ }
+ else if (nrows_interleaved == 4) {
+ out_ptr = (block_q4_0x4 *) dst;
+ }
+ assert(nrows_interleaved <= 8);
+ block_q4_0 dst_tmp[8];
+
+ for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
+
+ for (int64_t x = 0; x < nb; x++) {
+
+ for (int i = 0; i < nrows_interleaved; i++ ) {
+ quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
+ }
+
+ if (nrows_interleaved == 8) {
+ *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
+ out_ptr = (block_q4_0x8 *) out_ptr + 1;
+ }
+ else if (nrows_interleaved == 4) {
+ *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
+ out_ptr = (block_q4_0x4 *) out_ptr + 1;
+ }
+ }
+ }
+
+ return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
+}
+
+size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
+ }
+ else {
+ assert(false);
+ return 0;
+ }
+}
+
+size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
+ }
+ else {
+ assert(false);
+ return 0;
+ }
+}
+
+size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
+ }
+ else {
+ assert(false);
+ return 0;
+ }
+}
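+
+// Note (derived from the wrappers above): the repacked formats map to the interleaving
+// parameters as Q4_0_4_4 -> (nrows_interleaved = 4, blck_size_interleave = 4),
+// Q4_0_4_8 -> (4, 8) and Q4_0_8_8 -> (8, 8).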
+
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert (n % qk == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+ if (svcntw() == 8) {
+ GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+ "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+
+ __asm__ __volatile__(
+ "movi v31.16b, #0x4\n"
+ "movi v30.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x22, %x[a_ptr], #0x2\n"
+ "movi v29.16b, #0x0\n"
+ "mov x21, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q28, [%x[b_ptr], #0x0]\n"
+ "ldr q27, [x22, #0x0]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x22, #0x2\n"
+ "ldr q25, [x22, #0x10]\n"
+ "ldr q24, [%x[b_ptr], #0x10]\n"
+ "sub x21, x21, #0x1\n"
+ "add x22, x22, #0x22\n"
+ "ldr q23, [%x[b_ptr], #0x20]\n"
+ "ldr q22, [%x[b_ptr], #0x30]\n"
+ "ld1r { v21.8h }, [x20]\n"
+ "ldr q20, [%x[b_ptr], #-0x8]\n"
+ "sshl v16.16b, v28.16b, v31.16b\n"
+ "and v28.16b, v28.16b, v30.16b\n"
+ "sshl v19.16b, v24.16b, v31.16b\n"
+ "and v24.16b, v24.16b, v30.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "sshl v18.16b, v23.16b, v31.16b\n"
+ "and v23.16b, v23.16b, v30.16b\n"
+ ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
+ "sshl v17.16b, v22.16b, v31.16b\n"
+ "and v22.16b, v22.16b, v30.16b\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v16.4s, v20.4h\n"
+ ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
+ "fmul v16.4s, v16.4s, v21.4s\n"
+ ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
+ ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
+ ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
+ ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
+ ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
+ ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v29.4s, v26.4s, v16.4s\n"
+ "cbnz x21, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q29, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
+ );
+#else
+ float sumf[4];
+ int sumi;
+
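+    // Descriptive note on the scalar fallback below: thanks to the 0x88 XOR applied at packing
+    // time, each nibble already stores the signed Q4_0 value, so (int8_t)(q << 4) yields the low
+    // nibble scaled by 16 and (q & 0xF0) the high nibble scaled by 16; the two products are
+    // summed and ">> 4" removes the common factor of 16 before scaling by the fp16 deltas.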
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+ }
+#endif
+}
+
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert (n % qk == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE)
+ if (svcntw() == 8) {
+ GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+
+ __asm__ __volatile__(
+ "movi v2.16b, #0x4\n"
+ "movi v1.16b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x8\n"
+ "1:" // Column loop
+ "add x23, %x[a_ptr], #0x2\n"
+ "movi v0.16b, #0x0\n"
+ "mov x22, %x[nb]\n"
+ "2:" // Block loop
+ "ldr q31, [%x[b_ptr], #0x0]\n"
+ "ldr q30, [%x[b_ptr], #0x10]\n"
+ "mov x21, x23\n"
+ "movi v29.4s, #0x0\n"
+ "ldr q28, [%x[b_ptr], #0x20]\n"
+ "ldr q27, [%x[b_ptr], #0x30]\n"
+ "movi v26.4s, #0x0\n"
+ "sub x20, x23, #0x2\n"
+ "ld1r { v25.8h }, [x20]\n"
+ "ldr q24, [%x[b_ptr], #-0x8]\n"
+ "sub x22, x22, #0x1\n"
+ "add x23, x23, #0x22\n"
+ "ld1r { v23.2d }, [x21], #0x8\n"
+ "sshl v22.16b, v31.16b, v2.16b\n"
+ "sshl v16.16b, v30.16b, v2.16b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x48\n"
+ "ld1r { v21.2d }, [x21], #0x8\n"
+ "sshl v20.16b, v28.16b, v2.16b\n"
+ "sshl v19.16b, v27.16b, v2.16b\n"
+ "ld1r { v18.2d }, [x21], #0x8\n"
+ "ld1r { v17.2d }, [x21], #0x8\n"
+ "and v31.16b, v31.16b, v1.16b\n"
+ "and v30.16b, v30.16b, v1.16b\n"
+ ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
+ ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
+ "and v28.16b, v28.16b, v1.16b\n"
+ "and v27.16b, v27.16b, v1.16b\n"
+ "fcvtl v25.4s, v25.4h\n"
+ "fcvtl v16.4s, v24.4h\n"
+ ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
+ ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
+ "fmul v16.4s, v16.4s, v25.4s\n"
+ ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
+ ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
+ ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
+ ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
+ "addp v29.4s, v29.4s, v26.4s\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v0.4s, v29.4s, v16.4s\n"
+ "cbnz x22, 2b\n"
+ "sub %x[nc], %x[nc], #0x4\n"
+ "str q0, [%x[res_ptr], #0x0]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
+ );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+ "performance");
+#else
+ float sumf[4];
+ int sumi;
+
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+ }
+#endif
+}
+
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 8;
+ const int blocklen = 8;
+
+ assert (n % qk == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ if (svcntw() == 8) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+
+ __asm__ __volatile__(
+ "ptrue p0.b\n"
+ "add %x[b_ptr], %x[b_ptr], #0x10\n"
+ "1:" // Column loop
+ "add x22, %x[a_ptr], #0x2\n"
+ "mov z31.b, #0x0\n"
+ "mov x21, %x[nb]\n"
+ "2:" // Block loop
+ "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
+ "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
+ "mov z28.s, #0x0\n"
+ "mov z27.s, #0x0\n"
+ "ld1rd { z26.d }, p0/Z, [x22]\n"
+ "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
+ "sub x20, x22, #0x2\n"
+ "sub x21, x21, #0x1\n"
+ "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
+ "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
+ "lsl z22.b, z30.b, #0x4\n"
+ "lsl z16.b, z29.b, #0x4\n"
+ "and z30.b, z30.b, #0xf0\n"
+ "and z29.b, z29.b, #0xf0\n"
+ "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
+ "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
+ "lsl z19.b, z25.b, #0x4\n"
+ "and z25.b, z25.b, #0xf0\n"
+ "ld1rh { z17.h }, p0/Z, [x20]\n"
+ "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
+ "sdot z28.s, z22.b, z26.b\n"
+ "sdot z27.s, z16.b, z26.b\n"
+ "lsl z16.b, z24.b, #0x4\n"
+ "add x22, x22, #0x22\n"
+ "and z24.b, z24.b, #0xf0\n"
+ "add %x[b_ptr], %x[b_ptr], #0x90\n"
+ "fcvt z17.s, p0/m, z17.h\n"
+ "fcvt z18.s, p0/m, z18.h\n"
+ "sdot z28.s, z19.b, z23.b\n"
+ "sdot z27.s, z16.b, z23.b\n"
+ "fmul z18.s, z18.s, z17.s\n"
+ "sdot z28.s, z30.b, z21.b\n"
+ "sdot z27.s, z29.b, z21.b\n"
+ "sdot z28.s, z25.b, z20.b\n"
+ "sdot z27.s, z24.b, z20.b\n"
+ "uzp1 z17.s, z28.s, z27.s\n"
+ "uzp2 z16.s, z28.s, z27.s\n"
+ "add z17.s, z17.s, z16.s\n"
+ "asr z17.s, z17.s, #0x4\n"
+ "scvtf z17.s, p0/m, z17.s\n"
+ "fmla z31.s, p0/M, z17.s, z18.s\n"
+ "cbnz x21, 2b\n"
+ "sub %x[nc], %x[nc], #0x8\n"
+ "st1w { z31.s }, p0, [%x[res_ptr]]\n"
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
+ "cbnz %x[nc], 1b\n"
+ : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+ : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+ : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+ return;
+ }
+ else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+ "performance");
+ }
+ else if (ggml_cpu_has_neon()) {
+ GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+ "quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ GGML_ASSERT(ggml_cpu_has_sve() &&
+ "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+ "performance");
+#else
+ float sumf[8];
+ int sumi;
+
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+ }
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+ }
+#endif
+}
+
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert (n % qk == 0);
+ assert (nr % 4 == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (svcntw() == 8) {
+ GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) &&
+ "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
+
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v23.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v0.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v8.16b, #0x0\n"
+ "movi v1.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q3, [x28, #0x0]\n"
+ "ldr q31, [x25, #0x0]\n"
+ "movi v28.16b, #0x4\n"
+ "movi v10.4s, #0x0\n"
+ "ldr q22, [x28, #0x10]\n"
+ "ldr q6, [x25, #0x10]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ "ldr q27, [x28, #0x20]\n"
+ "ldr q30, [x28, #0x30]\n"
+ "movi v20.4s, #0x0\n"
+ "movi v24.16b, #0xf0\n"
+ "ldr d2, [x25, #-0x8]\n"
+ "ldr d26, [x23, #-0x8]\n"
+ "sshl v12.16b, v3.16b, v28.16b\n"
+ "sub x20, x28, #0x8\n"
+ "ldr d17, [x20, #0x0]\n"
+ "and v3.16b, v3.16b, v24.16b\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n"
+ ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n"
+ ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n"
+ ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n"
+ "sshl v31.16b, v22.16b, v28.16b\n"
+ "and v22.16b, v22.16b, v24.16b\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "fcvtl v2.4s, v2.4h\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n"
+ ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n"
+ ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n"
+ ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n"
+ "sshl v6.16b, v27.16b, v28.16b\n"
+ "sshl v28.16b, v30.16b, v28.16b\n"
+ "and v27.16b, v27.16b, v24.16b\n"
+ "and v30.16b, v30.16b, v24.16b\n"
+ "ldr q24, [x25, #0x20]\n"
+ ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x30]\n"
+ ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x40]\n"
+ ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x50]\n"
+ ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n"
+ ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x60]\n"
+ ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n"
+ ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n"
+ ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n"
+ "fmul v24.4s, v17.4s, v2.s[0]\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v15.4s, v10.4s, v24.4s\n"
+ "ldr q24, [x23, #0x0]\n"
+ "fmul v10.4s, v17.4s, v2.s[1]\n"
+ "fmla v19.4s, v29.4s, v10.4s\n"
+ "ldr q10, [x23, #0x10]\n"
+ "fmul v29.4s, v17.4s, v2.s[2]\n"
+ "fmul v2.4s, v17.4s, v2.s[3]\n"
+ "fmla v18.4s, v9.4s, v29.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n"
+ "fmla v14.4s, v20.4s, v2.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x20]\n"
+ ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n"
+ ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n"
+ ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x30]\n"
+ ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x40]\n"
+ ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n"
+ ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n"
+ ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n"
+ ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x50]\n"
+ ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x23, #0x60]\n"
+ ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n"
+ ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n"
+ ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n"
+ ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n"
+ "ldr q10, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x0]\n"
+ ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n"
+ ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n"
+ ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n"
+ ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n"
+ "fmul v10.4s, v17.4s, v26.s[0]\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v11.4s, v9.4s, v10.4s\n"
+ "ldr q9, [x22, #0x10]\n"
+ "fmul v10.4s, v17.4s, v26.s[1]\n"
+ "fmla v13.4s, v29.4s, v10.4s\n"
+ "ldr d29, [x22, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v26.s[2]\n"
+ "fmul v26.4s, v17.4s, v26.s[3]\n"
+ "fcvtl v29.4s, v29.4h\n"
+ "fmla v23.4s, v20.4s, v10.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v16.4s, v2.4s, v26.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v2.4s, #0x0\n"
+ ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x20]\n"
+ ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x30]\n"
+ ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n"
+ ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x40]\n"
+ ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x50]\n"
+ ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n"
+ ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n"
+ "ldr q24, [x22, #0x60]\n"
+ ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n"
+ "ldr q9, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n"
+ ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n"
+ ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n"
+ "ldr q24, [x21, #0x0]\n"
+ ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n"
+ "fmul v9.4s, v17.4s, v29.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "fmla v25.4s, v20.4s, v9.4s\n"
+ "ldr q9, [x21, #0x10]\n"
+ "fmul v20.4s, v17.4s, v29.s[1]\n"
+ "fmla v7.4s, v10.4s, v20.4s\n"
+ "ldr d20, [x21, #-0x8]\n"
+ "fmul v10.4s, v17.4s, v29.s[2]\n"
+ "fmul v29.4s, v17.4s, v29.s[3]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ "fmla v0.4s, v26.4s, v10.4s\n"
+ "movi v26.4s, #0x0\n"
+ "movi v10.4s, #0x0\n"
+ "fmla v4.4s, v2.4s, v29.4s\n"
+ "movi v2.4s, #0x0\n"
+ "movi v29.4s, #0x0\n"
+ ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n"
+ ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n"
+ "ldr q12, [x21, #0x20]\n"
+ "fmul v24.4s, v17.4s, v20.s[0]\n"
+ ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n"
+ ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n"
+ ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x30]\n"
+ "fmul v31.4s, v17.4s, v20.s[1]\n"
+ ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n"
+ ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n"
+ ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x40]\n"
+ "fmul v6.4s, v17.4s, v20.s[2]\n"
+ "fmul v20.4s, v17.4s, v20.s[3]\n"
+ ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n"
+ ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n"
+ "ldr q9, [x21, #0x50]\n"
+ ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n"
+ ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n"
+ ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n"
+ ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n"
+ "ldr q12, [x21, #0x60]\n"
+ ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n"
+ ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n"
+ ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n"
+ ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n"
+ "ldr q17, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n"
+ ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n"
+ ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n"
+ ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n"
+ ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n"
+ ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n"
+ ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n"
+ ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "scvtf v10.4s, v10.4s, #0x4\n"
+ "fmla v5.4s, v26.4s, v24.4s\n"
+ "scvtf v2.4s, v2.4s, #0x4\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "fmla v21.4s, v10.4s, v31.4s\n"
+ "fmla v8.4s, v2.4s, v6.4s\n"
+ "fmla v1.4s, v29.4s, v20.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q16, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q0, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q21, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q8, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q1, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v15.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v18.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q7, [x24, #0x0]\n"
+ "ldr q5, [x25, #0x0]\n"
+ "movi v9.16b, #0x4\n"
+ "movi v4.4s, #0x0\n"
+ "ldr q3, [x24, #0x10]\n"
+ "ldr q2, [x25, #0x10]\n"
+ "movi v1.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q13, [x24, #0x20]\n"
+ "ldr q31, [x25, #0x20]\n"
+ "movi v30.4s, #0x0\n"
+ "movi v29.16b, #0xf0\n"
+ "ldr q28, [x24, #0x30]\n"
+ "ldr q27, [x25, #0x30]\n"
+ "sshl v20.16b, v7.16b, v9.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr q26, [x25, #0x40]\n"
+ "ldr q25, [x25, #0x50]\n"
+ "sshl v17.16b, v3.16b, v9.16b\n"
+ "and v7.16b, v7.16b, v29.16b\n"
+ "ldr q24, [x25, #0x60]\n"
+ "ldr q16, [x25, #0x70]\n"
+ "sshl v22.16b, v13.16b, v9.16b\n"
+ "and v3.16b, v3.16b, v29.16b\n"
+ "ldr d21, [x20, #0x0]\n"
+ "ldr d12, [x25, #-0x8]\n"
+ ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n"
+ ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n"
+ ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n"
+ ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n"
+ "sshl v9.16b, v28.16b, v9.16b\n"
+ "subs x21, x21, #0x1\n"
+ "and v13.16b, v13.16b, v29.16b\n"
+ "and v28.16b, v28.16b, v29.16b\n"
+ "add x25, x25, #0x88\n"
+ "add x24, x24, #0x48\n"
+ "fcvtl v21.4s, v21.4h\n"
+ "fcvtl v12.4s, v12.4h\n"
+ ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n"
+ ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n"
+ ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n"
+ ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n"
+ "fmul v11.4s, v21.4s, v12.s[0]\n"
+ "fmul v23.4s, v21.4s, v12.s[1]\n"
+ "fmul v17.4s, v21.4s, v12.s[2]\n"
+ ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n"
+ "fmul v6.4s, v21.4s, v12.s[3]\n"
+ ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n"
+ ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n"
+ ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n"
+ ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n"
+ ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n"
+ ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n"
+ ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n"
+ ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n"
+ ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n"
+ ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n"
+ ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n"
+ ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n"
+ ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n"
+ ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n"
+ ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n"
+ ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n"
+ ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n"
+ ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n"
+ ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n"
+ ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n"
+ ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n"
+ ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n"
+ ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n"
+ "scvtf v4.4s, v4.4s, #0x4\n"
+ "scvtf v1.4s, v1.4s, #0x4\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "fmla v15.4s, v4.4s, v11.4s\n"
+ "scvtf v30.4s, v30.4s, #0x4\n"
+ "fmla v19.4s, v1.4s, v23.4s\n"
+ "fmla v18.4s, v0.4s, v17.4s\n"
+ "fmla v14.4s, v30.4s, v6.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q15, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q19, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q18, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q14, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+#else
+ float sumf[4][4];
+ int sumi;
+
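+    // Descriptive note: each (y, x) iteration below accumulates a 4 x ncols_interleaved tile of
+    // the output, combining one block_q8_0x4 (4 interleaved activation rows) with one
+    // block_q4_0x4 (4 interleaved weight columns) per block l, then scaling by the per-row and
+    // per-column fp16 deltas.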
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ }
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++)
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
+ }
+ }
+#endif
+}
+
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert (n % qk == 0);
+ assert (nr % 4 == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+ if (svcntw() == 8) {
+ GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
+
+ __asm__ __volatile__(
+ "mov x10, %x[nr]\n"
+ "mov x9, #0x88\n"
+ "cmp x10, #0x10\n"
+ "mul x9, %x[nb], x9\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x28, %x[b_ptr], #0x8\n"
+ "mov x27, %x[nc]\n"
+ "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x25, %x[a_ptr], #0x8\n"
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "mov x24, %x[nb]\n"
+ "add x23, x25, x9\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "add x22, x23, x9\n"
+ "movi v11.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "add x21, x22, x9\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v5.16b, #0x0\n"
+ "movi v7.16b, #0x0\n"
+ "movi v4.16b, #0x0\n"
+ "movi v6.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "3:" // Block loop
+ "ldr q21, [x28, #0x0]\n"
+ "ldr q16, [x28, #0x10]\n"
+ "movi v1.16b, #0x4\n"
+ "movi v19.4s, #0x0\n"
+ "ldr q27, [x25, #0x0]\n"
+ "ldr q15, [x25, #0x10]\n"
+ "movi v26.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ "ldr q29, [x28, #0x20]\n"
+ "ldr q3, [x28, #0x30]\n"
+ "movi v17.4s, #0x0\n"
+ "movi v0.16b, #0xf0\n"
+ "ldr d20, [x25, #-0x8]\n"
+ "ldr d9, [x23, #-0x8]\n"
+ "sshl v8.16b, v21.16b, v1.16b\n"
+ "sshl v31.16b, v16.16b, v1.16b\n"
+ "and v21.16b, v21.16b, v0.16b\n"
+ "and v16.16b, v16.16b, v0.16b\n"
+ "sub x20, x28, #0x8\n"
+ "subs x24, x24, #0x1\n"
+ "add x28, x28, #0x48\n"
+ ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n"
+ ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n"
+ "ldr q27, [x25, #0x20]\n"
+ ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n"
+ ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n"
+ "sshl v15.16b, v29.16b, v1.16b\n"
+ "sshl v1.16b, v3.16b, v1.16b\n"
+ "and v29.16b, v29.16b, v0.16b\n"
+ "and v3.16b, v3.16b, v0.16b\n"
+ "ldr q0, [x25, #0x30]\n"
+ "fcvtl v20.4s, v20.4h\n"
+ ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n"
+ "fcvtl v9.4s, v9.4h\n"
+ ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n"
+ "ldr q27, [x25, #0x40]\n"
+ ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n"
+ ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n"
+ "ldr q0, [x25, #0x50]\n"
+ ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n"
+ ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n"
+ "ldr q27, [x25, #0x60]\n"
+ ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n"
+ ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n"
+ "ldr q0, [x25, #0x70]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n"
+ ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n"
+ "ldr d27, [x20, #0x0]\n"
+ ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n"
+ ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n"
+ "fcvtl v27.4s, v27.4h\n"
+ "uzp1 v0.2d, v19.2d, v26.2d\n"
+ "uzp2 v26.2d, v19.2d, v26.2d\n"
+ "fmul v19.4s, v27.4s, v20.s[0]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v26.4s, v26.4s, #0x4\n"
+ "fmla v2.4s, v0.4s, v19.4s\n"
+ "ldr q19, [x23, #0x0]\n"
+ "uzp1 v0.2d, v18.2d, v17.2d\n"
+ "uzp2 v18.2d, v18.2d, v17.2d\n"
+ "fmul v17.4s, v27.4s, v20.s[1]\n"
+ "scvtf v0.4s, v0.4s, #0x4\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v10.4s, v26.4s, v17.4s\n"
+ "ldr q17, [x23, #0x10]\n"
+ "fmul v26.4s, v27.4s, v20.s[2]\n"
+ "fmul v20.4s, v27.4s, v20.s[3]\n"
+ "fmla v12.4s, v0.4s, v26.4s\n"
+ "ldr d0, [x22, #-0x8]\n"
+ "ldr d26, [x21, #-0x8]\n"
+ "fcvtl v0.4s, v0.4h\n"
+ "fmla v28.4s, v18.4s, v20.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x23, #0x20]\n"
+ "fcvtl v26.4s, v26.4h\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x23, #0x40]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q19, [x23, #0x60]\n"
+ ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n"
+ ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n"
+ "uzp1 v19.2d, v20.2d, v18.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp2 v20.2d, v20.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v9.s[0]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v11.4s, v19.4s, v18.4s\n"
+ "ldr q18, [x22, #0x0]\n"
+ "fmul v19.4s, v27.4s, v9.s[1]\n"
+ "fmla v13.4s, v20.4s, v19.4s\n"
+ "movi v19.4s, #0x0\n"
+ "movi v20.4s, #0x0\n"
+ ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x23, #0x30]\n"
+ ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x23, #0x50]\n"
+ ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x23, #0x70]\n"
+ "add x23, x23, #0x88\n"
+ ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v9.s[2]\n"
+ "fmul v9.4s, v27.4s, v9.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v22.4s, v17.4s, v19.4s\n"
+ "ldr q17, [x22, #0x10]\n"
+ "movi v19.4s, #0x0\n"
+ ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n"
+ "fmla v23.4s, v20.4s, v9.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v9.4s, #0x0\n"
+ ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n"
+ "ldr q18, [x22, #0x20]\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n"
+ ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n"
+ "ldr q18, [x22, #0x40]\n"
+ ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n"
+ ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n"
+ "ldr q18, [x22, #0x60]\n"
+ ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n"
+ ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x22, #0x30]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n"
+ "ldr q17, [x22, #0x50]\n"
+ ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n"
+ ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n"
+ "ldr q17, [x22, #0x70]\n"
+ "add x22, x22, #0x88\n"
+ ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n"
+ ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n"
+ "uzp1 v17.2d, v19.2d, v20.2d\n"
+ "uzp2 v20.2d, v19.2d, v20.2d\n"
+ "fmul v19.4s, v27.4s, v0.s[0]\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "fmla v25.4s, v17.4s, v19.4s\n"
+ "ldr q19, [x21, #0x0]\n"
+ "fmul v17.4s, v27.4s, v0.s[1]\n"
+ "fmla v5.4s, v20.4s, v17.4s\n"
+ "ldr q17, [x21, #0x10]\n"
+ "uzp1 v20.2d, v9.2d, v18.2d\n"
+ "uzp2 v9.2d, v9.2d, v18.2d\n"
+ "fmul v18.4s, v27.4s, v0.s[2]\n"
+ "fmul v0.4s, v27.4s, v0.s[3]\n"
+ "scvtf v20.4s, v20.4s, #0x4\n"
+ "scvtf v9.4s, v9.4s, #0x4\n"
+ "fmla v7.4s, v20.4s, v18.4s\n"
+ "movi v20.4s, #0x0\n"
+ "movi v18.4s, #0x0\n"
+ ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n"
+ ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n"
+ "ldr q19, [x21, #0x20]\n"
+ "fmla v4.4s, v9.4s, v0.4s\n"
+ "movi v9.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n"
+ "fmul v8.4s, v27.4s, v26.s[0]\n"
+ ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n"
+ "ldr q17, [x21, #0x30]\n"
+ ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n"
+ "fmul v31.4s, v27.4s, v26.s[1]\n"
+ ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n"
+ "ldr q19, [x21, #0x40]\n"
+ ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n"
+ "fmul v15.4s, v27.4s, v26.s[2]\n"
+ "fmul v27.4s, v27.4s, v26.s[3]\n"
+ ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n"
+ "ldr q1, [x21, #0x50]\n"
+ ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n"
+ ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n"
+ "ldr q26, [x21, #0x60]\n"
+ ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n"
+ ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n"
+ "ldr q21, [x21, #0x70]\n"
+ "add x21, x21, #0x88\n"
+ ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n"
+ ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n"
+ ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n"
+ ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n"
+ "uzp1 v29.2d, v20.2d, v18.2d\n"
+ "uzp2 v21.2d, v20.2d, v18.2d\n"
+ "scvtf v29.4s, v29.4s, #0x4\n"
+ "uzp1 v18.2d, v9.2d, v0.2d\n"
+ "uzp2 v16.2d, v9.2d, v0.2d\n"
+ "scvtf v21.4s, v21.4s, #0x4\n"
+ "fmla v6.4s, v29.4s, v8.4s\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v30.4s, v21.4s, v31.4s\n"
+ "fmla v24.4s, v18.4s, v15.4s\n"
+ "fmla v14.4s, v16.4s, v27.4s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x27, x27, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q28, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q11, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q13, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q22, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q23, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q25, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q5, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q7, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q4, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q6, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q30, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q24, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "str q14, [x20, #0x0]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x10, x10, #0x10\n"
+ "cmp x10, #0x10\n"
+ "mov %x[res_ptr], x26\n"
+ "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x10, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x24, %x[b_ptr], #0x8\n"
+ "mov x23, %x[nc]\n"
+ "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "movi v2.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "add x25, %x[a_ptr], #0x8\n"
+ "mov x21, %x[nb]\n"
+ "movi v12.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ldr q6, [x24, #0x0]\n"
+ "ldr q5, [x24, #0x10]\n"
+ "movi v17.16b, #0x4\n"
+ "movi v8.4s, #0x0\n"
+ "ldr q4, [x25, #0x0]\n"
+ "ldr q13, [x25, #0x10]\n"
+ "movi v27.4s, #0x0\n"
+ "movi v0.4s, #0x0\n"
+ "ldr q31, [x24, #0x20]\n"
+ "ldr q14, [x24, #0x30]\n"
+ "movi v29.4s, #0x0\n"
+ "movi v22.16b, #0xf0\n"
+ "ldr q11, [x25, #0x20]\n"
+ "ldr q23, [x25, #0x30]\n"
+ "sshl v21.16b, v6.16b, v17.16b\n"
+ "sshl v16.16b, v5.16b, v17.16b\n"
+ "ldr q20, [x25, #0x40]\n"
+ "ldr q26, [x25, #0x50]\n"
+ "and v6.16b, v6.16b, v22.16b\n"
+ "and v5.16b, v5.16b, v22.16b\n"
+ "ldr q25, [x25, #0x60]\n"
+ "ldr q3, [x25, #0x70]\n"
+ "sshl v19.16b, v31.16b, v17.16b\n"
+ "sshl v18.16b, v14.16b, v17.16b\n"
+ "ldr d17, [x25, #-0x8]\n"
+ ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n"
+ ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n"
+ "and v31.16b, v31.16b, v22.16b\n"
+ ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n"
+ ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n"
+ "and v14.16b, v14.16b, v22.16b\n"
+ "sub x20, x24, #0x8\n"
+ "ldr d16, [x20, #0x0]\n"
+ "subs x21, x21, #0x1\n"
+ "add x25, x25, #0x88\n"
+ "fcvtl v17.4s, v17.4h\n"
+ "add x24, x24, #0x48\n"
+ ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n"
+ ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n"
+ ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n"
+ ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n"
+ "fcvtl v16.4s, v16.4h\n"
+ ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n"
+ ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n"
+ "fmul v23.4s, v16.4s, v17.s[0]\n"
+ "fmul v21.4s, v16.4s, v17.s[1]\n"
+ "fmul v1.4s, v16.4s, v17.s[2]\n"
+ "fmul v20.4s, v16.4s, v17.s[3]\n"
+ ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n"
+ ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n"
+ ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n"
+ ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n"
+ ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n"
+ ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n"
+ "uzp1 v19.2d, v8.2d, v27.2d\n"
+ "uzp2 v18.2d, v8.2d, v27.2d\n"
+ "scvtf v19.4s, v19.4s, #0x4\n"
+ "uzp1 v17.2d, v0.2d, v29.2d\n"
+ "uzp2 v16.2d, v0.2d, v29.2d\n"
+ "scvtf v18.4s, v18.4s, #0x4\n"
+ "fmla v2.4s, v19.4s, v23.4s\n"
+ "scvtf v17.4s, v17.4s, #0x4\n"
+ "scvtf v16.4s, v16.4s, #0x4\n"
+ "fmla v10.4s, v18.4s, v21.4s\n"
+ "fmla v12.4s, v17.4s, v1.4s\n"
+ "fmla v28.4s, v16.4s, v20.4s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x10, #0x1\n"
+ "str q2, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x2\n"
+ "str q10, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x10, #0x3\n"
+ "str q12, [x20, #0x0]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "str q28, [x20, #0x0]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x23, x23, #0x4\n"
+ "add %x[res_ptr], %x[res_ptr], #0x10\n"
+ "bne 6b\n"
+ "subs x10, x10, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x9\n"
+ "mov %x[res_ptr], x22\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+ "performance");
+#else
+ float sumf[4][4];
+ int sumi;
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ }
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++)
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
+ }
+ }
+#endif
+}
+
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 8;
+ const int blocklen = 8;
+
+ assert (n % qk == 0);
+ assert (nr % 4 == 0);
+ assert (nc % ncols_interleaved == 0);
+
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__))
+ if (svcntw() == 8) {
+ const void * b_ptr = vx;
+ const void * a_ptr = vy;
+ float * res_ptr = s;
+ size_t res_stride = bs * sizeof(float);
+
+ __asm__ __volatile__(
+ "mov x20, #0x4\n"
+ "mov x13, %x[nr]\n"
+ "mov z28.s, #-0x4\n"
+ "mov x12, #0x88\n"
+ "ptrue p1.b\n"
+ "whilelt p0.s, XZR, x20\n"
+ "cmp x13, #0x10\n"
+ "mul x12, %x[nb], x12\n"
+ "blt 4f\n"
+ "1:" // Row loop
+ "add x11, %x[b_ptr], #0x10\n"
+ "mov x10, %x[nc]\n"
+ "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
+ "2:" // Column loop
+ "add x28, %x[a_ptr], #0x8\n"
+ "mov z24.b, #0x0\n"
+ "mov z15.b, #0x0\n"
+ "mov x27, %x[nb]\n"
+ "add x26, x28, x12\n"
+ "mov z12.b, #0x0\n"
+ "mov z0.b, #0x0\n"
+ "add x25, x26, x12\n"
+ "mov z13.b, #0x0\n"
+ "mov z1.b, #0x0\n"
+ "add x24, x25, x12\n"
+ "mov z20.b, #0x0\n"
+ "mov z25.b, #0x0\n"
+ "mov z11.b, #0x0\n"
+ "mov z16.b, #0x0\n"
+ "mov z19.b, #0x0\n"
+ "mov z26.b, #0x0\n"
+ "mov z8.b, #0x0\n"
+ "mov z29.b, #0x0\n"
+ "mov z27.b, #0x0\n"
+ "mov z10.b, #0x0\n"
+ "3:" // Block loop
+ "ld1b { z30.b }, p1/Z, [x11]\n"
+ "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
+ "mov z18.s, #0x0\n"
+ "mov z7.s, #0x0\n"
+ "ld1rqb { z3.b }, p1/Z, [x28]\n"
+ "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
+ "mov z9.s, #0x0\n"
+ "mov z22.s, #0x0\n"
+ "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
+ "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
+ "sub x20, x11, #0x10\n"
+ "sub x23, x28, #0x8\n"
+ "lsl z31.b, z30.b, #0x4\n"
+ "lsl z6.b, z21.b, #0x4\n"
+ "ld1h { z23.s }, p1/Z, [x20]\n"
+ "sub x22, x26, #0x8\n"
+ "and z30.b, z30.b, #0xf0\n"
+ "and z21.b, z21.b, #0xf0\n"
+ "sub x21, x25, #0x8\n"
+ "sub x20, x24, #0x8\n"
+ "lsl z14.b, z4.b, #0x4\n"
+ "lsl z2.b, z17.b, #0x4\n"
+ "subs x27, x27, #0x1\n"
+ "add x11, x11, #0x90\n"
+ ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n"
+ ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n"
+ "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
+ "and z4.b, z4.b, #0xf0\n"
+ ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
+ ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
+ "and z17.b, z17.b, #0xf0\n"
+ "fcvt z23.s, p1/m, z23.h\n"
+ ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n"
+ ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n"
+ "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
+ ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
+ ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
+ "fscale z23.s, p1/m, z23.s, z28.s\n"
+ ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n"
+ ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n"
+ "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
+ ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
+ ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
+ "add x28, x28, #0x88\n"
+ ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n"
+ ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n"
+ "ld1h { z3.s }, p0/Z, [x23]\n"
+ ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
+ ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
+ "fcvt z3.s, p1/m, z3.h\n"
+ "uzp1 z5.d, z18.d, z7.d\n"
+ "uzp2 z18.d, z18.d, z7.d\n"
+ "mov z3.q, z3.q[0]\n"
+ "uzp1 z7.d, z9.d, z22.d\n"
+ "uzp2 z22.d, z9.d, z22.d\n"
+ "fmul z9.s, z23.s, z3.s[0]\n"
+ "scvtf z5.s, p1/m, z5.s\n"
+ "scvtf z18.s, p1/m, z18.s\n"
+ "scvtf z7.s, p1/m, z7.s\n"
+ "scvtf z22.s, p1/m, z22.s\n"
+ "fmla z24.s, p1/M, z5.s, z9.s\n"
+ "ld1rqb { z5.b }, p1/Z, [x26]\n"
+ "fmul z9.s, z23.s, z3.s[1]\n"
+ "fmla z15.s, p1/M, z18.s, z9.s\n"
+ "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
+ "fmul z9.s, z23.s, z3.s[2]\n"
+ "fmul z3.s, z23.s, z3.s[3]\n"
+ "fmla z12.s, p1/M, z7.s, z9.s\n"
+ "mov z9.s, #0x0\n"
+ "ld1h { z7.s }, p0/Z, [x22]\n"
+ ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n"
+ "fmla z0.s, p1/M, z22.s, z3.s\n"
+ "mov z22.s, #0x0\n"
+ "ld1h { z3.s }, p0/Z, [x21]\n"
+ ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
+ "fcvt z7.s, p1/m, z7.h\n"
+ "fcvt z3.s, p1/m, z3.h\n"
+ ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n"
+ ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
+ "mov z7.q, z7.q[0]\n"
+ "mov z3.q, z3.q[0]\n"
+ ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n"
+ ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
+ ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n"
+ ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n"
+ "uzp1 z5.d, z9.d, z22.d\n"
+ "scvtf z5.s, p1/m, z5.s\n"
+ "uzp2 z22.d, z9.d, z22.d\n"
+ "fmul z9.s, z23.s, z7.s[0]\n"
+ "scvtf z22.s, p1/m, z22.s\n"
+ "fmla z13.s, p1/M, z5.s, z9.s\n"
+ "ld1rqb { z9.b }, p1/Z, [x25]\n"
+ "fmul z5.s, z23.s, z7.s[1]\n"
+ "fmla z1.s, p1/M, z22.s, z5.s\n"
+ "mov z5.s, #0x0\n"
+ "mov z22.s, #0x0\n"
+ ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n"
+ ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
+ ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n"
+ ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
+ ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n"
+ ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
+ "add x26, x26, #0x88\n"
+ ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n"
+ ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n"
+ "uzp1 z18.d, z5.d, z22.d\n"
+ "scvtf z18.s, p1/m, z18.s\n"
+ "uzp2 z22.d, z5.d, z22.d\n"
+ "fmul z5.s, z23.s, z7.s[2]\n"
+ "fmul z7.s, z23.s, z7.s[3]\n"
+ "scvtf z22.s, p1/m, z22.s\n"
+ "fmla z20.s, p1/M, z18.s, z5.s\n"
+ "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
+ "ld1h { z5.s }, p0/Z, [x20]\n"
+ "fcvt z5.s, p1/m, z5.h\n"
+ "fmla z25.s, p1/M, z22.s, z7.s\n"
+ "mov z22.s, #0x0\n"
+ "mov z7.s, #0x0\n"
+ ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n"
+ ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
+ "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
+ "mov z5.q, z5.q[0]\n"
+ ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n"
+ ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
+ "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
+ ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n"
+ ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n"
+ "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
+ ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n"
+ ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n"
+ "uzp1 z9.d, z22.d, z7.d\n"
+ "scvtf z9.s, p1/m, z9.s\n"
+ "uzp2 z22.d, z22.d, z7.d\n"
+ "fmul z7.s, z23.s, z3.s[0]\n"
+ "scvtf z22.s, p1/m, z22.s\n"
+ "fmla z11.s, p1/M, z9.s, z7.s\n"
+ "ld1rqb { z9.b }, p1/Z, [x24]\n"
+ "fmul z7.s, z23.s, z3.s[1]\n"
+ "fmla z16.s, p1/M, z22.s, z7.s\n"
+ "mov z22.s, #0x0\n"
+ "mov z7.s, #0x0\n"
+ ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n"
+ ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
+ ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n"
+ ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
+ ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n"
+ ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
+ "add x25, x25, #0x88\n"
+ ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n"
+ ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n"
+ "uzp1 z18.d, z22.d, z7.d\n"
+ "scvtf z18.s, p1/m, z18.s\n"
+ "uzp2 z7.d, z22.d, z7.d\n"
+ "fmul z22.s, z23.s, z3.s[2]\n"
+ "fmul z3.s, z23.s, z3.s[3]\n"
+ "scvtf z7.s, p1/m, z7.s\n"
+ "fmla z19.s, p1/M, z18.s, z22.s\n"
+ "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
+ "fmul z22.s, z23.s, z5.s[0]\n"
+ "fmla z26.s, p1/M, z7.s, z3.s\n"
+ "mov z3.s, #0x0\n"
+ "mov z7.s, #0x0\n"
+ ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n"
+ ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n"
+ "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
+ ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n"
+ ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n"
+ "mov z9.s, #0x0\n"
+ ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n"
+ "mov z31.s, #0x0\n"
+ ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n"
+ "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
+ "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
+ ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n"
+ "fmul z14.s, z23.s, z5.s[1]\n"
+ ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n"
+ "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
+ "fmul z2.s, z23.s, z5.s[2]\n"
+ "fmul z23.s, z23.s, z5.s[3]\n"
+ ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n"
+ ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n"
+ "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
+ ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n"
+ ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n"
+ "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
+ "add x24, x24, #0x88\n"
+ ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n"
+ ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n"
+ ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n"
+ ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n"
+ "uzp1 z18.d, z3.d, z7.d\n"
+ "uzp2 z5.d, z3.d, z7.d\n"
+ "scvtf z18.s, p1/m, z18.s\n"
+ "uzp1 z6.d, z9.d, z31.d\n"
+ "uzp2 z9.d, z9.d, z31.d\n"
+ "scvtf z5.s, p1/m, z5.s\n"
+ "fmla z8.s, p1/M, z18.s, z22.s\n"
+ "scvtf z6.s, p1/m, z6.s\n"
+ "scvtf z9.s, p1/m, z9.s\n"
+ "fmla z29.s, p1/M, z5.s, z14.s\n"
+ "fmla z27.s, p1/M, z6.s, z2.s\n"
+ "fmla z10.s, p1/M, z9.s, z23.s\n"
+ "bgt 3b\n"
+ "mov x20, %x[res_ptr]\n"
+ "subs x10, x10, #0x8\n"
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
+ "st1w { z24.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z15.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z12.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z0.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z13.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z1.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z20.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z25.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z11.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z16.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z19.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z26.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z8.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z29.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z27.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "st1w { z10.s }, p1, [x20]\n"
+ "bne 2b\n"
+ "mov x20, #0x4\n"
+ "sub x13, x13, #0x10\n"
+ "cmp x13, #0x10\n"
+ "mov %x[res_ptr], x9\n"
+ "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
+ "bge 1b\n"
+ "4:" // Row loop skip
+ "cbz x13, 9f\n"
+ "5:" // Row tail: Row loop
+ "add x25, %x[b_ptr], #0x10\n"
+ "mov x24, %x[nc]\n"
+ "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
+ "6:" // Row tail: Column loop
+ "mov z24.b, #0x0\n"
+ "mov z15.b, #0x0\n"
+ "add x28, %x[a_ptr], #0x8\n"
+ "mov x22, %x[nb]\n"
+ "mov z12.b, #0x0\n"
+ "mov z0.b, #0x0\n"
+ "7:" // Row tail: Block loop
+ "ld1b { z3.b }, p1/Z, [x25]\n"
+ "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
+ "mov z2.s, #0x0\n"
+ "mov z25.s, #0x0\n"
+ "ld1rqb { z26.b }, p1/Z, [x28]\n"
+ "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
+ "mov z27.s, #0x0\n"
+ "mov z19.s, #0x0\n"
+ "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
+ "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
+ "sub x21, x25, #0x10\n"
+ "sub x20, x28, #0x8\n"
+ "lsl z20.b, z3.b, #0x4\n"
+ "lsl z4.b, z6.b, #0x4\n"
+ "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
+ "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
+ "and z3.b, z3.b, #0xf0\n"
+ "and z6.b, z6.b, #0xf0\n"
+ "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
+ "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
+ "lsl z8.b, z29.b, #0x4\n"
+ "lsl z14.b, z16.b, #0x4\n"
+ "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
+ "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
+ ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n"
+ ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n"
+ "and z29.b, z29.b, #0xf0\n"
+ "ld1h { z17.s }, p1/Z, [x21]\n"
+ ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n"
+ ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n"
+ "and z16.b, z16.b, #0xf0\n"
+ "ld1h { z4.s }, p0/Z, [x20]\n"
+ "subs x22, x22, #0x1\n"
+ "add x28, x28, #0x88\n"
+ "fcvt z17.s, p1/m, z17.h\n"
+ "add x25, x25, #0x90\n"
+ ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n"
+ ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n"
+ "fcvt z4.s, p1/m, z4.h\n"
+ ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n"
+ ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n"
+ "fscale z17.s, p1/m, z17.s, z28.s\n"
+ "mov z4.q, z4.q[0]\n"
+ ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n"
+ ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n"
+ "fmul z23.s, z17.s, z4.s[0]\n"
+ "fmul z9.s, z17.s, z4.s[1]\n"
+ "fmul z21.s, z17.s, z4.s[2]\n"
+ "fmul z4.s, z17.s, z4.s[3]\n"
+ ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n"
+ ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n"
+ ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n"
+ ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n"
+ ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n"
+ ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n"
+ "uzp1 z31.d, z2.d, z25.d\n"
+ "uzp2 z13.d, z2.d, z25.d\n"
+ "scvtf z31.s, p1/m, z31.s\n"
+ "uzp1 z17.d, z27.d, z19.d\n"
+ "uzp2 z18.d, z27.d, z19.d\n"
+ "scvtf z13.s, p1/m, z13.s\n"
+ "fmla z24.s, p1/M, z31.s, z23.s\n"
+ "scvtf z17.s, p1/m, z17.s\n"
+ "scvtf z18.s, p1/m, z18.s\n"
+ "fmla z15.s, p1/M, z13.s, z9.s\n"
+ "fmla z12.s, p1/M, z17.s, z21.s\n"
+ "fmla z0.s, p1/M, z18.s, z4.s\n"
+ "bgt 7b\n"
+ "mov x20, %x[res_ptr]\n"
+ "cmp x13, #0x1\n"
+ "st1w { z24.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x13, #0x2\n"
+ "st1w { z15.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "cmp x13, #0x3\n"
+ "st1w { z12.s }, p1, [x20]\n"
+ "add x20, x20, %x[res_stride]\n"
+ "ble 8f\n"
+ "st1w { z0.s }, p1, [x20]\n"
+ "8:" // Row tail: Accumulator store skip
+ "subs x24, x24, #0x8\n"
+ "add %x[res_ptr], %x[res_ptr], #0x20\n"
+ "bne 6b\n"
+ "subs x13, x13, #0x4\n"
+ "add %x[a_ptr], %x[a_ptr], x12\n"
+ "mov %x[res_ptr], x23\n"
+ "bgt 5b\n"
+ "9:" // Row tail: Row loop skip
+ : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+ : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+ : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+ return;
+ }
+ else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) &&
+ "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal "
+ "performance");
+ }
+ else if (ggml_cpu_has_neon()) {
+ GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 "
+ "quantization format for optimal performance");
+ }
+#endif
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ GGML_ASSERT(ggml_cpu_has_sve() &&
+ "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance");
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) &&
+ "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
+ "performance");
+#else
+ float sumf[4][8];
+ int sumi;
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+ sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+ (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+ }
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++)
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
+ }
+ }
+#endif
+}
diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h
new file mode 100644
index 00000000..517babaf
--- /dev/null
+++ b/ggml/src/ggml-aarch64.h
@@ -0,0 +1,39 @@
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+// GEMV
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+// GEMM
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+#ifdef __cplusplus
+}
+#endif
+
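The declarations above are the whole public surface of this file; in the tree they are reached from ggml's mat-mul path rather than called directly. As a rough, hypothetical sketch of how the 4x4-interleaved path fits together (the driver function, its buffer sizing, and the shapes are illustrative assumptions, not part of this patch):

// Hypothetical standalone driver for the Q4_0 4x4-interleaved GEMM declared above.
// Assumes n % QK8_0 == 0, nr % 4 == 0 and nc % 4 == 0 (see the asserts in ggml-aarch64.c).
#include <stdlib.h>
#include "ggml-aarch64.h"

static void gemm_q4_0_4x4_sketch(const float * w,   // nc x n row-major weights
                                 const float * a,   // nr x n row-major activations
                                 float       * s,   // nr x nc row-major result
                                 int n, int nr, int nc) {
    // interleaved Q4_0 weights take the same byte count as plain Q4_0 rows
    void * wq = malloc((size_t) nc * (n / QK4_0) * sizeof(block_q4_0));
    // interleaved Q8_0 activations: 4 rows packed per block group, same total bytes as Q8_0
    void * aq = malloc((size_t) nr * (n / QK8_0) * sizeof(block_q8_0));

    quantize_q4_0_4x4(w, wq, nc, n, /*imatrix=*/ NULL);
    quantize_mat_q8_0(a, aq, nr, n, /*blck_size_interleave=*/ 4);

    // bs is the row stride of s in floats (here: contiguous rows of length nc)
    ggml_gemm_q4_0_4x4_q8_0(n, s, /*bs=*/ nc, wq, aq, nr, nc);

    free(wq);
    free(aq);
}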
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
new file mode 100644
index 00000000..e176b883
--- /dev/null
+++ b/ggml/src/ggml-alloc.c
@@ -0,0 +1,1042 @@
+#include "ggml-alloc.h"
+#include "ggml-backend-impl.h"
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MAX_FREE_BLOCKS 256
+
+//#define GGML_ALLOCATOR_DEBUG
+
+//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+#define AT_PRINTF(...)
+
+
+static bool ggml_is_view(const struct ggml_tensor * t) {
+ return t->view_src != NULL;
+}
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+ if (a->type != b->type) {
+ return false;
+ }
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (a->ne[i] != b->ne[i]) {
+ return false;
+ }
+ if (a->nb[i] != b->nb[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool ggml_op_can_inplace(enum ggml_op op) {
+ switch (op) {
+ case GGML_OP_SCALE:
+ case GGML_OP_DIAG_MASK_ZERO:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_LOG:
+ case GGML_OP_UNARY:
+ case GGML_OP_ROPE:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SOFT_MAX:
+ return true;
+
+ default:
+ return false;
+ }
+}
+
+static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
+ assert(alignment && !(alignment & (alignment - 1))); // power of 2
+ size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
+ return offset + align;
+}
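aligned_offset() rounds `buffer + offset` up to the next multiple of `alignment` and returns the adjusted offset. A tiny standalone check of that rounding (the helper is static, so the same expression is repeated here; the addresses are arbitrary examples):

#include <assert.h>
#include <stdint.h>
#include <stddef.h>

static size_t round_up_offset(uintptr_t buffer, size_t offset, size_t alignment) {
    size_t align = (alignment - ((buffer + offset) % alignment)) % alignment;
    return offset + align;
}

int main(void) {
    assert(round_up_offset(0x1000,  0, 16) ==  0);  // already aligned
    assert(round_up_offset(0x1000,  5, 16) == 16);  // 0x1005 rounds up to 0x1010
    assert(round_up_offset(0x1008, 10, 16) == 24);  // 0x1012 rounds up to 0x1020
    return 0;
}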
+
+// tallocr
+
+struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {
+ void * base = ggml_backend_buffer_get_base(buffer);
+ size_t align = ggml_backend_buffer_get_alignment(buffer);
+
+ assert(align && !(align & (align - 1))); // power of 2
+
+ struct ggml_tallocr talloc = (struct ggml_tallocr) {
+ /*.buffer = */ buffer,
+ /*.base = */ base,
+ /*.alignment = */ align,
+ /*.offset = */ aligned_offset(base, 0, align),
+ };
+ return talloc;
+}
+
+void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
+ size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
+ size = GGML_PAD(size, talloc->alignment);
+
+ if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
+ fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+ __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
+ GGML_ASSERT(!"not enough space in the buffer");
+ return;
+ }
+
+ void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
+ talloc->offset += size;
+
+ assert(((uintptr_t)addr % talloc->alignment) == 0);
+
+ ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+}
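These two functions are the entire linear allocator: ggml_tallocr_new records the buffer's base and alignment, and each ggml_tallocr_alloc call bumps the offset by the padded tensor size. A minimal usage sketch, assuming a no_alloc ggml context and a CPU backend buffer (sizes and shapes are arbitrary examples):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void tallocr_sketch(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,                 // metadata only, data lives in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

    ggml_backend_buffer_t buf =
        ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), 1024 * 1024);

    struct ggml_tallocr talloc = ggml_tallocr_new(buf);
    ggml_tallocr_alloc(&talloc, a);             // each call advances talloc.offset
    ggml_tallocr_alloc(&talloc, b);

    // ... use the tensors ...

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
}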
+
+// dynamic tensor allocator
+
+struct free_block {
+ size_t offset;
+ size_t size;
+};
+
+struct ggml_dyn_tallocr {
+ size_t alignment;
+ int n_free_blocks;
+ struct free_block free_blocks[MAX_FREE_BLOCKS];
+ size_t max_size;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+ struct {
+ const struct ggml_tensor * tensor;
+ size_t offset;
+ } allocated_tensors[1024];
+#endif
+};
+
+#ifdef GGML_ALLOCATOR_DEBUG
+static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+ for (int i = 0; i < 1024; i++) {
+ if (alloc->allocated_tensors[i].tensor == NULL) {
+ alloc->allocated_tensors[i].tensor = tensor;
+ alloc->allocated_tensors[i].offset = offset;
+ return;
+ }
+ }
+ GGML_ASSERT(!"out of allocated_tensors");
+}
+static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+ for (int i = 0; i < 1024; i++) {
+ if (alloc->allocated_tensors[i].offset == offset) {
+ alloc->allocated_tensors[i].tensor = NULL;
+ return;
+ }
+ }
+ fprintf(stderr, "tried to free tensor %s not found\n", tensor->name);
+ GGML_ASSERT(!"tensor not found");
+}
+#endif
+
+static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) {
+ size = aligned_offset(NULL, size, alloc->alignment);
+
+ AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
+
+ size_t max_avail = 0;
+
+ // find the best fitting free block besides the last block
+ int best_fit_block = -1;
+ size_t best_fit_size = SIZE_MAX;
+ for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
+ struct free_block * block = &alloc->free_blocks[i];
+ max_avail = MAX(max_avail, block->size);
+ if (block->size >= size && block->size <= best_fit_size) {
+ best_fit_block = i;
+ best_fit_size = block->size;
+ }
+ }
+
+ if (best_fit_block == -1) {
+ // the last block is our last resort
+ struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
+ max_avail = MAX(max_avail, block->size);
+ if (block->size >= size) {
+ best_fit_block = alloc->n_free_blocks - 1;
+ } else {
+ // this should never happen
+ fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+ __func__, size, max_avail);
+ GGML_ASSERT(!"not enough space in the buffer");
+ GGML_UNREACHABLE();
+ }
+ }
+
+ struct free_block * block = &alloc->free_blocks[best_fit_block];
+ size_t offset = block->offset;
+ block->offset = offset + size;
+ block->size -= size;
+ if (block->size == 0) {
+ // remove block if empty
+ alloc->n_free_blocks--;
+ for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
+ }
+ }
+
+ AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+ add_allocated_tensor(alloc, offset, tensor);
+ size_t cur_max = offset + size;
+ if (cur_max > alloc->max_size) {
+ // sort allocated_tensors by offset
+ for (int i = 0; i < 1024; i++) {
+ for (int j = i + 1; j < 1024; j++) {
+ if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) {
+ const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;
+ size_t tmp_offset = alloc->allocated_tensors[i].offset;
+ alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;
+ alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset;
+ alloc->allocated_tensors[j].tensor = tmp_tensor;
+ alloc->allocated_tensors[j].offset = tmp_offset;
+ }
+ }
+ }
+ fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+ for (int i = 0; i < 1024; i++) {
+ if (alloc->allocated_tensors[i].tensor) {
+ fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+ alloc->allocated_tensors[i].offset,
+ alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
+ ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
+ }
+ }
+ fprintf(stderr, "\n");
+ }
+#endif
+
+ alloc->max_size = MAX(alloc->max_size, offset + size);
+
+ return offset;
+
+ GGML_UNUSED(tensor);
+}
+
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) {
+ size = aligned_offset(NULL, size, alloc->alignment);
+
+ AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+ remove_allocated_tensor(alloc, offset, tensor);
+#endif
+
+ // see if we can merge with an existing block
+ for (int i = 0; i < alloc->n_free_blocks; i++) {
+ struct free_block * block = &alloc->free_blocks[i];
+ // check if ptr is at the end of the block
+ if (block->offset + block->size == offset) {
+ block->size += size;
+ // check if we can merge with the next block
+ if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) {
+ block->size += alloc->free_blocks[i+1].size;
+ alloc->n_free_blocks--;
+ for (int j = i+1; j < alloc->n_free_blocks; j++) {
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
+ }
+ }
+ return;
+ }
+ // check if ptr is at the beginning of the block
+ if (offset + size == block->offset) {
+ block->offset = offset;
+ block->size += size;
+ // check if we can merge with the previous block
+ if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) {
+ alloc->free_blocks[i-1].size += block->size;
+ alloc->n_free_blocks--;
+ for (int j = i; j < alloc->n_free_blocks; j++) {
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
+ }
+ }
+ return;
+ }
+ }
+ // otherwise, add a new block
+ GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
+ // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
+ int insert_pos = 0;
+ while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) {
+ insert_pos++;
+ }
+ // shift all blocks from insert_pos onward to make room for the new block
+ for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
+ alloc->free_blocks[i] = alloc->free_blocks[i-1];
+ }
+ // insert the new block
+ alloc->free_blocks[insert_pos].offset = offset;
+ alloc->free_blocks[insert_pos].size = size;
+ alloc->n_free_blocks++;
+
+ GGML_UNUSED(tensor);
+}
+
+static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
+ alloc->n_free_blocks = 1;
+ alloc->free_blocks[0].offset = 0;
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+ alloc->max_size = 0;
+}
+
+static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
+ struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr));
+
+ *alloc = (struct ggml_dyn_tallocr) {
+ /*.alignment = */ alignment,
+ /*.n_free_blocks = */ 0,
+ /*.free_blocks = */ {{0}},
+ /*.max_size = */ 0,
+#ifdef GGML_ALLOCATOR_DEBUG
+ /*.allocated_tensors = */ {{0}},
+#endif
+ };
+
+ ggml_dyn_tallocr_reset(alloc);
+
+ return alloc;
+}
+
+static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
+ free(alloc);
+}
+
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
+ return alloc->max_size;
+}
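To make the best-fit search and block merging above concrete, a short worked trace with alignment 16 (request sizes are arbitrary; the offsets follow from the code above):

// after reset:           free = { [0, SIZE_MAX/2) }
// alloc A, 100 B -> 112  returns offset 0,   free = { [112, ...) }
// alloc B, 256 B         returns offset 112, free = { [368, ...) }
// alloc C,  64 B         returns offset 368, free = { [432, ...) }
// free  B (112, 256)     no adjacent block, inserted sorted:      free = { [112, 368), [432, ...) }
// free  C (368,  64)     appended to [112, 368), merged with next: free = { [112, ...) }
// max_size so far: 432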
+
+
+/////////////////////////////////////
+
+// graph allocator
+
+struct hash_node {
+ int n_children;
+ int n_views;
+ int buffer_id;
+ size_t offset; // offset within the buffer
+ bool allocated;
+};
+
+struct tensor_alloc {
+ int buffer_id;
+ size_t offset;
+ size_t size_max; // 0 = pre-allocated, unused, or view
+};
+
+struct leaf_alloc {
+ int buffer_id;
+ struct tensor_alloc leaf;
+};
+
+struct node_alloc {
+ struct tensor_alloc dst;
+ struct tensor_alloc src[GGML_MAX_SRC];
+};
+
+struct ggml_gallocr {
+ ggml_backend_buffer_type_t * bufts; // [n_buffers]
+ ggml_backend_buffer_t * buffers; // [n_buffers]
+ struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
+ int n_buffers;
+
+ struct ggml_hash_set hash_set;
+ struct hash_node * hash_values; // [hash_set.size]
+
+ struct node_alloc * node_allocs; // [n_nodes]
+ int n_nodes;
+
+ struct leaf_alloc * leaf_allocs; // [n_leafs]
+ int n_leafs;
+};
+
+ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) {
+ ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr));
+ GGML_ASSERT(galloc != NULL);
+
+ galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t));
+ GGML_ASSERT(galloc->bufts != NULL);
+
+ galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
+ GGML_ASSERT(galloc->buffers != NULL);
+
+ galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
+ GGML_ASSERT(galloc->buf_tallocs != NULL);
+
+ for (int i = 0; i < n_bufs; i++) {
+ galloc->bufts[i] = bufts[i];
+ galloc->buffers[i] = NULL;
+
+ // check if the same buffer type is used multiple times and reuse the same allocator
+ for (int j = 0; j < i; j++) {
+ if (bufts[i] == bufts[j]) {
+ galloc->buf_tallocs[i] = galloc->buf_tallocs[j];
+ break;
+ }
+ }
+
+ if (galloc->buf_tallocs[i] == NULL) {
+ size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
+ galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+ }
+ }
+ galloc->n_buffers = n_bufs;
+
+ return galloc;
+}
+
+ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) {
+ return ggml_gallocr_new_n(&buft, 1);
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+ if (galloc == NULL) {
+ return;
+ }
+
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ if (galloc->buffers != NULL) {
+ // skip if already freed
+ bool freed = false;
+ for (int j = 0; j < i; j++) {
+ if (galloc->buffers[j] == galloc->buffers[i]) {
+ freed = true;
+ break;
+ }
+ }
+ if (!freed) {
+ ggml_backend_buffer_free(galloc->buffers[i]);
+ }
+ }
+ if (galloc->buf_tallocs != NULL) {
+ // skip if already freed
+ bool freed = false;
+ for (int j = 0; j < i; j++) {
+ if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+ freed = true;
+ break;
+ }
+ }
+ if (!freed) {
+ ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+ }
+ }
+ }
+
+ free(galloc->hash_set.keys);
+ free(galloc->hash_values);
+ free(galloc->bufts);
+ free(galloc->buffers);
+ free(galloc->buf_tallocs);
+ free(galloc->node_allocs);
+ free(galloc->leaf_allocs);
+ free(galloc);
+}
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
+ return &galloc->hash_values[i];
+}
+
+static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ return ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ hn->buffer_id = buffer_id;
+ hn->offset = offset;
+ hn->allocated = true;
+}
+
+static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+
+ if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+ hn->allocated = true;
+ assert(hn->offset == 0);
+
+ // try to reuse a parent's buffer (inplace)
+ if (ggml_op_can_inplace(node->op)) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ struct ggml_tensor * parent = node->src[i];
+ if (parent == NULL) {
+ continue;
+ }
+
+ // if the node's data is external, then we cannot re-use it
+ if (!ggml_gallocr_is_own(galloc, parent)) {
+ AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
+ continue;
+ }
+
+ // outputs cannot be reused
+ if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) {
+ AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name);
+ continue;
+ }
+
+ if (!ggml_are_same_layout(node, parent)) {
+ AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name);
+ continue;
+ }
+
+ struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+ if (p_hn->n_children == 1 && p_hn->n_views == 0) {
+ if (ggml_is_view(parent)) {
+ struct ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+ if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
+ AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
+ assert(view_src_hn->offset == p_hn->offset);
+ hn->buffer_id = p_hn->buffer_id;
+ hn->offset = p_hn->offset;
+ p_hn->allocated = false; // avoid freeing the parent
+ view_src_hn->allocated = false;
+ return;
+ }
+ } else {
+ AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
+ hn->buffer_id = p_hn->buffer_id;
+ hn->offset = p_hn->offset;
+ p_hn->allocated = false; // avoid freeing the parent
+ return;
+ }
+ }
+ }
+ }
+ // allocate tensor from the buffer
+ struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+ ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+ size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+ size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node);
+ hn->buffer_id = buffer_id;
+ hn->offset = offset;
+ return;
+ }
+}
+
+static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+ // graph outputs are never freed
+ if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+ AT_PRINTF("not freeing output %s\n", node->name);
+ return;
+ }
+
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ size_t offset = hn->offset;
+ int buffer_id = hn->buffer_id;
+ struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+ ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+ size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+ ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
+ hn->allocated = false;
+}
+
+static int get_node_buffer_id(const int * node_buffer_ids, int i) {
+ return node_buffer_ids ? node_buffer_ids[i] : 0;
+}
+
+static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+ // clear hash tables
+ memset(galloc->hash_set.keys, 0, galloc->hash_set.size * sizeof(struct ggml_tensor *));
+ memset(galloc->hash_values, 0, galloc->hash_set.size * sizeof(struct hash_node));
+
+ // allocate leafs
+    // these may be tensors that the application is not using in the graph, but that it may still want to allocate for other purposes
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));
+ }
+
+ // count number of children and views
+ // allocate other graph inputs and leafs first to avoid overwriting them
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+
+ // TODO: better way to add external dependencies
+ // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
+ // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
+ // itself is never used and should not be considered a dependency
+ if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
+ struct ggml_tensor * view_src = node->view_src;
+ ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
+ }
+
+ if (node->flags & GGML_TENSOR_FLAG_INPUT) {
+ ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));
+ }
+
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+
+ ggml_gallocr_hash_get(galloc, src)->n_children += 1;
+
+ // allocate explicit inputs
+ if (src->flags & GGML_TENSOR_FLAG_INPUT) {
+ ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));
+ }
+ }
+ }
+
+ // allocate tensors
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ int buffer_id = get_node_buffer_id(node_buffer_ids, i);
+
+ // allocate parents (only leafs need to be allocated at this point)
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ continue;
+ }
+ ggml_gallocr_allocate_node(galloc, parent, buffer_id);
+ }
+
+ // allocate node
+ ggml_gallocr_allocate_node(galloc, node, buffer_id);
+
+ AT_PRINTF("exec: %s (%s) <= ", ggml_op_desc(node), node->name);
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ continue;
+ }
+ AT_PRINTF("%s", parent->name);
+ if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+ AT_PRINTF(", ");
+ }
+ }
+ AT_PRINTF("\n");
+
+ // update parents
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ continue;
+ }
+ struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+ p_hn->n_children -= 1;
+
+ AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n",
+ parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
+
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+ if (ggml_is_view(parent)) {
+ struct ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+ view_src_hn->n_views -= 1;
+ AT_PRINTF("view_src %s: %d children, %d views\n",
+ view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
+ ggml_gallocr_free_node(galloc, view_src);
+ }
+ }
+ else if (p_hn->allocated) {
+ ggml_gallocr_free_node(galloc, parent);
+ }
+ }
+ AT_PRINTF("\n");
+ }
+ }
+}
+
+bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+ size_t hash_size = graph->visited_hash_table.size;
+
+ // initialize hash table
+ if (galloc->hash_set.size < hash_size) {
+ free(galloc->hash_set.keys);
+ free(galloc->hash_values);
+ galloc->hash_set.size = hash_size;
+ galloc->hash_set.keys = calloc(hash_size, sizeof(struct ggml_tensor *));
+ galloc->hash_values = calloc(hash_size, sizeof(struct hash_node));
+ GGML_ASSERT(galloc->hash_set.keys != NULL);
+ GGML_ASSERT(galloc->hash_values != NULL);
+ } else {
+ // reset hash table
+ memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * galloc->hash_set.size);
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
+ }
+
+ // reset allocators
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);
+ }
+
+ // allocate in hash table
+ ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);
+
+ // set the node_allocs from the hash table
+ if (galloc->n_nodes < graph->n_nodes) {
+ free(galloc->node_allocs);
+ galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc));
+ GGML_ASSERT(galloc->node_allocs != NULL);
+ }
+ galloc->n_nodes = graph->n_nodes;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
+ if (node->view_src || node->data) {
+ node_alloc->dst.buffer_id = -1;
+ node_alloc->dst.offset = SIZE_MAX;
+ node_alloc->dst.size_max = 0;
+ } else {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ node_alloc->dst.buffer_id = hn->buffer_id;
+ node_alloc->dst.offset = hn->offset;
+ node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+ }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (!src || src->view_src || src->data) {
+ node_alloc->src[j].buffer_id = -1;
+ node_alloc->src[j].offset = SIZE_MAX;
+ node_alloc->src[j].size_max = 0;
+ } else {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);
+ node_alloc->src[j].buffer_id = hn->buffer_id;
+ node_alloc->src[j].offset = hn->offset;
+ node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
+ }
+ }
+ }
+ if (galloc->n_leafs < graph->n_leafs) {
+ free(galloc->leaf_allocs);
+ galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0]));
+ GGML_ASSERT(galloc->leaf_allocs != NULL);
+ }
+ galloc->n_leafs = graph->n_leafs;
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
+ galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
+ if (leaf->view_src || leaf->data) {
+ galloc->leaf_allocs[i].leaf.buffer_id = -1;
+ galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
+ galloc->leaf_allocs[i].leaf.size_max = 0;
+ } else {
+ galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;
+ galloc->leaf_allocs[i].leaf.offset = hn->offset;
+ galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
+ }
+ }
+
+ // reallocate buffers if needed
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ // if the buffer type is used multiple times, we reuse the same buffer
+ for (int j = 0; j < i; j++) {
+ if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+ galloc->buffers[i] = galloc->buffers[j];
+ break;
+ }
+ }
+
+ size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
+ size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
+
+ // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
+ if (new_size > cur_size || galloc->buffers[i] == NULL) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+
+ ggml_backend_buffer_free(galloc->buffers[i]);
+ galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+ if (galloc->buffers[i] == NULL) {
+ fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+ return false;
+ }
+ ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
+ }
+ }
+
+ return true;
+}
+
+bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
+}
+
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {
+ int buffer_id = tensor_alloc->buffer_id;
+ assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+
+ if (tensor->view_src != NULL) {
+ if (tensor->buffer == NULL) {
+ assert(tensor_alloc->offset == SIZE_MAX);
+ if (tensor->view_src->buffer == NULL) {
+ // this tensor was allocated without ggml-backend
+ return;
+ }
+ ggml_backend_view_init(tensor);
+ }
+ } else {
+ if (tensor->data == NULL) {
+ assert(tensor_alloc->offset != SIZE_MAX);
+ assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+ void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+ void * addr = (char *)base + tensor_alloc->offset;
+ ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr);
+ } else {
+ if (tensor->buffer == NULL) {
+ // this tensor was allocated without ggml-backend
+ return;
+ }
+ }
+ }
+}
+
+static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
+ ggml_backend_buffer_type_t buft = talloc->buffer_id != -1 ? galloc->bufts[talloc->buffer_id] : NULL;
+ size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(buft, node);
+ return talloc->size_max >= node_size;
+}
+
+static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ if (galloc->n_nodes != graph->n_nodes) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: graph has different number of nodes\n", __func__);
+#endif
+ return true;
+ }
+
+ if (galloc->n_leafs != graph->n_leafs) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: graph has different number of leafs\n", __func__);
+#endif
+ return true;
+ }
+
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
+
+ if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
+#endif
+ return true;
+ }
+
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+#endif
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ if (ggml_gallocr_needs_realloc(galloc, graph)) {
+ if (galloc->n_buffers == 1) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: reallocating buffers automatically\n", __func__);
+#endif
+ if (!ggml_gallocr_reserve(galloc, graph)) {
+ return false;
+ }
+ } else {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+#endif
+ return false;
+ }
+ }
+
+ // reset buffers
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ if (galloc->buffers[i] != NULL) {
+ ggml_backend_buffer_reset(galloc->buffers[i]);
+ }
+ }
+
+ // allocate the graph tensors from the previous assignments
+ // leafs
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
+ ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);
+ }
+ // nodes
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);
+ }
+ ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);
+ }
+
+ return true;
+}
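// Illustrative usage sketch (not part of this file/commit): the typical gallocr flow is
// to reserve once with a worst-case graph and then allocate each graph before compute.
// The graph `gf` is assumed to come from a no_alloc ggml context; the CPU buffer type is
// just an example choice.
static void example_gallocr_usage(struct ggml_cgraph * gf) {
    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    // reserve sizes the backend buffer for the graph; alloc_graph assigns tensor addresses
    // and, for single-buffer allocators, re-reserves automatically when the graph grows
    if (ggml_gallocr_reserve(galloc, gf) && ggml_gallocr_alloc_graph(galloc, gf)) {
        // ... compute the graph here, e.g. with ggml_backend_graph_compute ...
    }
    ggml_gallocr_free(galloc);
}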
+
+size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
+
+ if (galloc->buffers[buffer_id] == NULL) {
+ return 0;
+ }
+
+ for (int i = 0; i < buffer_id; i++) {
+ if (galloc->buffers[i] == galloc->buffers[buffer_id]) {
+ // this buffer is the same as a previous one due to the same buffer type being used multiple times
+ // only return the buffer size the first time it appears to avoid double counting
+ return 0;
+ }
+ }
+
+ return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
+}
+
+// utils
+
+static bool alloc_tensor_range(struct ggml_context * ctx,
+ struct ggml_tensor * first, struct ggml_tensor * last,
+ ggml_backend_buffer_type_t buft, size_t size,
+ ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
+ if (buffer == NULL) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+#endif
+ for (size_t i = 0; i < *n_buffers; i++) {
+ ggml_backend_buffer_free((*buffers)[i]);
+ }
+ free(*buffers);
+ return false;
+ }
+
+ struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
+
+ for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->data == NULL) {
+ if (t->view_src == NULL) {
+ ggml_tallocr_alloc(&tallocr, t);
+ } else if (t->buffer == NULL) {
+ ggml_backend_view_init(t);
+ }
+ } else {
+ if (t->view_src != NULL && t->buffer == NULL) {
+ // view of a pre-allocated tensor
+ ggml_backend_view_init(t);
+ }
+ }
+ }
+
+ *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));
+ (*buffers)[(*n_buffers)++] = buffer;
+
+ return true;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+ GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+ size_t alignment = ggml_backend_buft_get_alignment(buft);
+ size_t max_size = ggml_backend_buft_get_max_size(buft);
+
+ ggml_backend_buffer_t * buffers = NULL;
+ size_t n_buffers = 0;
+
+ size_t cur_buf_size = 0;
+ struct ggml_tensor * first = ggml_get_first_tensor(ctx);
+ for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ size_t this_size = 0;
+ if (t->data == NULL && t->view_src == NULL) {
+ this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+ }
+
+ if (this_size > max_size) {
+ fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
+ __func__, t->name,
+ ggml_backend_buft_name(buft),
+ this_size, max_size);
+ for (size_t i = 0; i < n_buffers; i++) {
+ ggml_backend_buffer_free(buffers[i]);
+ }
+ free(buffers);
+ return NULL;
+ }
+
+ if ((cur_buf_size + this_size) > max_size) {
+ // allocate tensors in the current buffer
+ if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
+ return NULL;
+ }
+ first = t;
+ cur_buf_size = this_size;
+ } else {
+ cur_buf_size += this_size;
+ }
+ }
+
+ // allocate remaining tensors
+ if (cur_buf_size > 0) {
+ if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
+ return NULL;
+ }
+ }
+
+ if (n_buffers == 0) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
+#endif
+ return NULL;
+ }
+
+ ggml_backend_buffer_t buffer;
+ if (n_buffers == 1) {
+ buffer = buffers[0];
+ } else {
+ buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
+ }
+ free(buffers);
+ return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+ return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
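// Illustrative sketch (not part of this commit): placing every tensor of a no_alloc
// context into a single backend buffer via ggml_backend_alloc_ctx_tensors(). The CPU
// backend and the tensor size are example assumptions.
static void example_alloc_ctx_tensors(void) {
    struct ggml_init_params params = {
        /* .mem_size   = */ 8*ggml_tensor_overhead(), // metadata only
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,                     // tensor data goes into the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    (void) a; // used only through the backend tensor API in a real program

    ggml_backend_t        backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);
    // ... upload data with ggml_backend_tensor_set(a, ...), build and compute graphs ...
    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
}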
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
new file mode 100644
index 00000000..36ca3708
--- /dev/null
+++ b/ggml/src/ggml-backend-impl.h
@@ -0,0 +1,153 @@
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ //
+ // Backend buffer
+ //
+
+ // buffer type
+ typedef void * ggml_backend_buffer_type_context_t;
+
+ struct ggml_backend_buffer_type_i {
+ const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft);
+ // allocate a buffer of this type
+ ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
+ // tensor alignment
+ size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft);
+ // max buffer size that can be allocated
+ size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft);
+ // data size needed to allocate the tensor, including padding
+ size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+ // check if tensor data is in host memory
+ bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft);
+ };
+
+ struct ggml_backend_buffer_type {
+ struct ggml_backend_buffer_type_i iface;
+ ggml_backend_buffer_type_context_t context;
+ };
+
+ // buffer
+ typedef void * ggml_backend_buffer_context_t;
+
+ struct ggml_backend_buffer_i {
+ const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer);
+ void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
+ void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer);
+ void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+ void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value);
+ void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+ };
+
+ struct ggml_backend_buffer {
+ struct ggml_backend_buffer_i iface;
+ ggml_backend_buffer_type_t buft;
+ ggml_backend_buffer_context_t context;
+ size_t size;
+ enum ggml_backend_buffer_usage usage;
+ };
+
+ GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
+ ggml_backend_buffer_type_t buft,
+ struct ggml_backend_buffer_i iface,
+ ggml_backend_buffer_context_t context,
+ size_t size);
+
+ // do not use directly, use ggml_backend_tensor_copy instead
+ bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ // buffer that contains a collection of buffers
+ GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
+ GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
+ GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+
+ //
+ // Backend
+ //
+
+ typedef void * ggml_backend_context_t;
+
+ struct ggml_backend_i {
+ const char * (*GGML_CALL get_name)(ggml_backend_t backend);
+
+ void (*GGML_CALL free)(ggml_backend_t backend);
+
+ // buffer allocation
+ ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend);
+
+ // (optional) asynchronous tensor data access
+ void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ // (optional) complete all pending operations
+ void (*GGML_CALL synchronize)(ggml_backend_t backend);
+
+ // compute graph with a plan (not used currently)
+ // create a new plan for a graph
+ ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
+ void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
+ void (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
+ // compute the graph with the plan
+ enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+ // compute graph without a plan (async)
+ enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+ // check if the backend can compute an operation
+ bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
+ // check if the backend can use tensors allocated in a buffer type
+ bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+
+ // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
+ // these should be expensive operations with large batch sizes that may benefit from running on this backend
+ // even if the weight has to be copied from the CPU temporarily
+ bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
+ // (optional) event synchronization
+ // create a new event that can be recorded on this backend instance
+ ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend);
+ void (*GGML_CALL event_free) (ggml_backend_event_t event);
+ // record an event on the backend instance that created it
+ void (*GGML_CALL event_record) (ggml_backend_event_t event);
+ // wait for an event on a different backend instance
+ void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
+ // block the host until the recorded event has completed
+ void (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
+ };
+
+ struct ggml_backend {
+ ggml_guid_t guid;
+
+ struct ggml_backend_i iface;
+ ggml_backend_context_t context;
+ };
+
+ struct ggml_backend_event {
+ ggml_backend_t backend;
+ void * context;
+ };
+
+ //
+ // Backend registry
+ //
+
+ typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data);
+
+ GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
new file mode 100644
index 00000000..d39cfed8
--- /dev/null
+++ b/ggml/src/ggml-backend.c
@@ -0,0 +1,2234 @@
+#include "ggml-backend-impl.h"
+#include "ggml-alloc.h"
+#include "ggml-impl.h"
+
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// backend buffer type
+
+const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name(buft);
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+ // get_max_size is optional, defaults to SIZE_MAX
+ if (buft->iface.get_max_size) {
+ return buft->iface.get_max_size(buft);
+ }
+ return SIZE_MAX;
+}
+
+GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+ // get_alloc_size is optional, defaults to ggml_nbytes
+ if (buft->iface.get_alloc_size) {
+ size_t size = buft->iface.get_alloc_size(buft, tensor);
+ assert(size >= ggml_nbytes(tensor));
+ return size;
+ }
+ return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+ if (buft->iface.is_host) {
+ return buft->iface.is_host(buft);
+ }
+ return false;
+}
+
+// backend buffer
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
+ ggml_backend_buffer_type_t buft,
+ struct ggml_backend_buffer_i iface,
+ ggml_backend_buffer_context_t context,
+ size_t size) {
+ ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
+
+ (*buffer) = (struct ggml_backend_buffer) {
+ /* .interface = */ iface,
+ /* .buft = */ buft,
+ /* .context = */ context,
+ /* .size = */ size,
+ /* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY
+ };
+
+ return buffer;
+}
+
+const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name(buffer);
+}
+
+void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+ if (buffer == NULL) {
+ return;
+ }
+
+ if (buffer->iface.free_buffer != NULL) {
+ buffer->iface.free_buffer(buffer);
+ }
+ free(buffer);
+}
+
+size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+ return buffer->size;
+}
+
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+ void * base = buffer->iface.get_base(buffer);
+
+ GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+ return base;
+}
+
+GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ // init_tensor is optional
+ if (buffer->iface.init_tensor) {
+ buffer->iface.init_tensor(buffer, tensor);
+ }
+}
+
+size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
+ return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
+ return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
+}
+
+void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ buffer->iface.clear(buffer, value);
+}
+
+bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
+ return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
+}
+
+void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ buffer->usage = usage;
+
+ // FIXME: add a generic callback to the buffer interface
+ if (ggml_backend_buffer_is_multi_buffer(buffer)) {
+ ggml_backend_multi_buffer_set_usage(buffer, usage);
+ }
+}
+
+enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+ return buffer->usage;
+}
+
+ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+ return buffer->buft;
+}
+
+void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+ if (buffer->iface.reset) {
+ buffer->iface.reset(buffer);
+ }
+}
+
+bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
+ if (dst_buf->iface.cpy_tensor) {
+ return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
+ }
+ return false;
+}
+
+// backend
+
+ggml_guid_t ggml_backend_guid(ggml_backend_t backend) {
+ if (backend == NULL) {
+ return NULL;
+ }
+ return backend->guid;
+}
+
+const char * ggml_backend_name(ggml_backend_t backend) {
+ if (backend == NULL) {
+ return "NULL";
+ }
+ return backend->iface.get_name(backend);
+}
+
+void ggml_backend_free(ggml_backend_t backend) {
+ if (backend == NULL) {
+ return;
+ }
+
+ backend->iface.free(backend);
+}
+
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+ return backend->iface.get_default_buffer_type(backend);
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
+ return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
+}
+
+size_t ggml_backend_get_alignment(ggml_backend_t backend) {
+ return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
+}
+
+size_t ggml_backend_get_max_size(ggml_backend_t backend) {
+ return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend));
+}
+
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+ if (backend->iface.set_tensor_async == NULL) {
+ ggml_backend_tensor_set(tensor, data, offset, size);
+ } else {
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+ }
+}
+
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+ if (backend->iface.get_tensor_async == NULL) {
+ ggml_backend_tensor_get(tensor, data, offset, size);
+ } else {
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+ }
+}
+
+GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf != NULL && "tensor buffer not set");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+ if (!size) {
+ return;
+ }
+
+ buf->iface.set_tensor(buf, tensor, data, offset, size);
+}
+
+GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf != NULL && "tensor buffer not set");
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+ if (!size) {
+ return;
+ }
+
+ buf->iface.get_tensor(buf, tensor, data, offset, size);
+}
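// Illustrative sketch (not part of this file): round-tripping data through a backend
// tensor with the synchronous accessors above; `t` is assumed to be an allocated
// F32 tensor with at least 16 elements.
static void example_tensor_roundtrip(struct ggml_tensor * t) {
    float host[16] = {0};
    ggml_backend_tensor_set(t, host, 0, sizeof(host)); // host -> backend buffer
    ggml_backend_tensor_get(t, host, 0, sizeof(host)); // backend buffer -> host
}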
+
+void ggml_backend_synchronize(ggml_backend_t backend) {
+ if (backend->iface.synchronize == NULL) {
+ return;
+ }
+
+ backend->iface.synchronize(backend);
+}
+
+ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(backend->iface.graph_plan_create != NULL);
+
+ return backend->iface.graph_plan_create(backend, cgraph);
+}
+
+void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend->iface.graph_plan_free != NULL);
+
+ backend->iface.graph_plan_free(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
+
+ return backend->iface.graph_plan_compute(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
+ ggml_backend_synchronize(backend);
+ return err;
+}
+
+enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ return backend->iface.graph_compute(backend, cgraph);
+}
+
+bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ return backend->iface.supports_op(backend, op);
+}
+
+bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ return backend->iface.supports_buft(backend, buft);
+}
+
+bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ if (backend->iface.offload_op != NULL) {
+ return backend->iface.offload_op(backend, op);
+ }
+ return false;
+}
+
+// backend copy
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+ if (a->type != b->type) {
+ return false;
+ }
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (a->ne[i] != b->ne[i]) {
+ return false;
+ }
+ if (a->nb[i] != b->nb[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+ if (src == dst) {
+ return;
+ }
+
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+ } else if (ggml_backend_buffer_is_host(dst->buffer)) {
+ ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+ } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
+#endif
+ size_t nbytes = ggml_nbytes(src);
+ void * data = malloc(nbytes);
+ ggml_backend_tensor_get(src, data, 0, nbytes);
+ ggml_backend_tensor_set(dst, data, 0, nbytes);
+ free(data);
+ }
+}
+
+void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
+ GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+ if (src == dst) {
+ return;
+ }
+
+ if (backend_dst->iface.cpy_tensor_async != NULL) {
+ if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
+ return;
+ }
+ }
+
+ // an async copy would normally happen after all the queued operations on both backends are completed
+ // sync src, set_async dst
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ ggml_backend_synchronize(backend_src);
+ ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, ggml_nbytes(src));
+ } else {
+ ggml_backend_synchronize(backend_src);
+ ggml_backend_tensor_copy(src, dst);
+ ggml_backend_synchronize(backend_dst);
+ }
+}
+
+// events
+
+ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) {
+ if (backend->iface.event_new == NULL) {
+ return NULL;
+ }
+ return backend->iface.event_new(backend);
+}
+
+void ggml_backend_event_free(ggml_backend_event_t event) {
+ if (event == NULL) {
+ return;
+ }
+ event->backend->iface.event_free(event);
+}
+
+void ggml_backend_event_record(ggml_backend_event_t event) {
+ GGML_ASSERT(event->backend->iface.event_record != NULL);
+
+ event->backend->iface.event_record(event);
+}
+
+void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+ GGML_ASSERT(event->backend->iface.event_synchronize != NULL);
+
+ event->backend->iface.event_synchronize(event);
+}
+
+void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ GGML_ASSERT(backend->iface.event_wait != NULL);
+
+ backend->iface.event_wait(backend, event);
+}
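// Illustrative sketch (not part of this file): the optional event API lets one backend
// wait on work submitted to another without blocking the host. backend_a/backend_b and
// the two graphs are assumptions; backends without event support return NULL from event_new.
static void example_event_sync(ggml_backend_t backend_a, ggml_backend_t backend_b,
                               struct ggml_cgraph * graph_a, struct ggml_cgraph * graph_b) {
    ggml_backend_event_t ev = ggml_backend_event_new(backend_a);
    if (ev == NULL) {
        return; // backend_a does not implement events
    }
    ggml_backend_graph_compute_async(backend_a, graph_a);
    ggml_backend_event_record(ev);          // record after the async submission
    ggml_backend_event_wait(backend_b, ev); // backend_b waits on the device side
    ggml_backend_graph_compute_async(backend_b, graph_b);
    ggml_backend_event_synchronize(ev);     // block the host until the event fires
    ggml_backend_event_free(ev);
}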
+
+// backend registry
+
+#define GGML_REG_MAX_BACKENDS 64
+
+struct ggml_backend_reg {
+ char name[128];
+ ggml_backend_init_fn init_fn;
+ ggml_backend_buffer_type_t default_buffer_type;
+ void * user_data;
+};
+
+static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS];
+static size_t ggml_backend_registry_count = 0;
+
+GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
+
+GGML_CALL static void ggml_backend_registry_init(void) {
+ static bool initialized = false;
+
+ if (initialized) {
+ return;
+ }
+
+ initialized = true;
+
+ ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL);
+
+ // add forward decls here to avoid including the backend headers
+#ifdef GGML_USE_CUDA
+ extern GGML_CALL void ggml_backend_cuda_reg_devices(void);
+ ggml_backend_cuda_reg_devices();
+#endif
+
+#ifdef GGML_USE_SYCL
+ extern void ggml_backend_sycl_reg_devices(void);
+ ggml_backend_sycl_reg_devices();
+#endif
+
+#ifdef GGML_USE_METAL
+ extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
+ extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+ ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
+#endif
+
+#ifdef GGML_USE_VULKAN
+ extern GGML_CALL int ggml_backend_vk_reg_devices(void);
+ ggml_backend_vk_reg_devices();
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+ extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
+ ggml_backend_kompute_reg_devices();
+#endif
+
+#ifdef GGML_USE_CANN
+ extern GGML_CALL int ggml_backend_cann_reg_devices(void);
+ ggml_backend_cann_reg_devices();
+#endif
+}
+
+GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
+ GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS);
+
+ size_t id = ggml_backend_registry_count;
+
+ ggml_backend_registry[id] = (struct ggml_backend_reg) {
+ /* .name = */ {0},
+ /* .fn = */ init_fn,
+ /* .default_buffer_type = */ default_buffer_type,
+ /* .user_data = */ user_data,
+ };
+
+ snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
+
+#ifndef NDEBUG
+ fprintf(stderr, "%s: registered backend %s\n", __func__, name);
+#endif
+
+ ggml_backend_registry_count++;
+}
+
+size_t ggml_backend_reg_get_count(void) {
+ ggml_backend_registry_init();
+
+ return ggml_backend_registry_count;
+}
+
+size_t ggml_backend_reg_find_by_name(const char * name) {
+ ggml_backend_registry_init();
+
+ for (size_t i = 0; i < ggml_backend_registry_count; i++) {
+ // TODO: case insensitive in a portable way
+ if (strcmp(ggml_backend_registry[i].name, name) == 0) {
+ return i;
+ }
+ }
+
+ // not found
+ return SIZE_MAX;
+}
+
+// init from backend:params string
+ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) {
+ ggml_backend_registry_init();
+
+ const char * params = strchr(backend_str, ':');
+ char backend_name[128];
+ if (params == NULL) {
+ snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
+ params = "";
+ } else {
+ snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
+ params++;
+ }
+
+ size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
+
+ if (backend_i == SIZE_MAX) {
+ fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
+ return NULL;
+ }
+
+ return ggml_backend_reg_init_backend(backend_i, params);
+}
+
+const char * ggml_backend_reg_get_name(size_t i) {
+ ggml_backend_registry_init();
+
+ GGML_ASSERT(i < ggml_backend_registry_count);
+ return ggml_backend_registry[i].name;
+}
+
+ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) {
+ ggml_backend_registry_init();
+
+ GGML_ASSERT(i < ggml_backend_registry_count);
+ return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data);
+}
+
+ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) {
+ ggml_backend_registry_init();
+
+ GGML_ASSERT(i < ggml_backend_registry_count);
+ return ggml_backend_registry[i].default_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
+ ggml_backend_registry_init();
+
+ GGML_ASSERT(i < ggml_backend_registry_count);
+ return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size);
+}
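// Illustrative sketch (not part of this file): initializing a backend through the
// registry by name. "CPU" is always registered; other names depend on build flags.
static ggml_backend_t example_init_from_registry(void) {
    size_t i = ggml_backend_reg_find_by_name("CPU");
    if (i == SIZE_MAX) {
        return NULL;
    }
    return ggml_backend_reg_init_backend(i, /* params = */ "");
    // equivalently: ggml_backend_reg_init_backend_from_str("CPU");
}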
+
+// backend CPU
+
+static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
+
+GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) {
+ return "CPU";
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+ uintptr_t data = (uintptr_t)buffer->context;
+
+ // align the buffer
+ if (data % TENSOR_ALIGNMENT != 0) {
+ data = GGML_PAD(data, TENSOR_ALIGNMENT);
+ }
+
+ return (void *)data;
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ memcpy((char *)tensor->data + offset, data, size);
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ memcpy(dst->data, src->data, ggml_nbytes(src));
+ return true;
+ }
+ return false;
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ memset(buffer->context, value, buffer->size);
+}
+
+static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
+ /* .get_name = */ ggml_backend_cpu_buffer_name,
+ /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_cpu_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// for buffers from ptr, free is not called
+static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
+ /* .get_name = */ ggml_backend_cpu_buffer_name,
+ /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+ /* .get_base = */ ggml_backend_cpu_buffer_get_base,
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_cpu_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "CPU";
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
+ void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+ if (data == NULL) {
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+ return NULL;
+ }
+
+ return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
+}
+
+GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return TENSOR_ALIGNMENT;
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return true;
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_cpu_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
+ },
+ /* .context = */ NULL,
+ };
+
+ return &ggml_backend_cpu_buffer_type;
+}
+
+#ifdef GGML_USE_CPU_HBM
+
+// buffer type HBM
+
+#include <hbwmalloc.h>
+
+GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "CPU_HBM";
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) {
+ return "CPU_HBM";
+
+ GGML_UNUSED(buf);
+}
+
+GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ hbw_free(buffer->context);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ //void * ptr = hbw_malloc(size);
+ void * ptr;
+ int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
+ if (result != 0) {
+ fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
+ return NULL;
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name;
+ buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
+
+ return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
+ },
+ /* .context = */ NULL,
+ };
+
+ return &ggml_backend_cpu_buffer_type_hbm;
+}
+#endif
+
+struct ggml_backend_cpu_context {
+ int n_threads;
+ void * work_data;
+ size_t work_size;
+
+ ggml_abort_callback abort_callback;
+ void * abort_callback_data;
+};
+
+GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
+ return "CPU";
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_cpu_free(ggml_backend_t backend) {
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+ free(cpu_ctx->work_data);
+ free(cpu_ctx);
+ free(backend);
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_cpu_buffer_type();
+
+ GGML_UNUSED(backend);
+}
+
+struct ggml_backend_plan_cpu {
+ struct ggml_cplan cplan;
+ struct ggml_cgraph cgraph;
+};
+
+GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+ struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
+
+ cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
+ cpu_plan->cgraph = *cgraph; // FIXME: deep copy
+
+ if (cpu_plan->cplan.work_size > 0) {
+ cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
+ if (cpu_plan->cplan.work_data == NULL) {
+ free(cpu_plan);
+ return NULL;
+ }
+ }
+
+ cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
+ cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+ return cpu_plan;
+}
+
+GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+ free(cpu_plan->cplan.work_data);
+ free(cpu_plan);
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+ return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
+
+ if (cpu_ctx->work_size < cplan.work_size) {
+ free(cpu_ctx->work_data);
+ cpu_ctx->work_data = malloc(cplan.work_size);
+ if (cpu_ctx->work_data == NULL) {
+ cpu_ctx->work_size = 0;
+ return GGML_STATUS_ALLOC_FAILED;
+ }
+ cpu_ctx->work_size = cplan.work_size;
+ }
+ cplan.work_data = cpu_ctx->work_data;
+
+ cplan.abort_callback = cpu_ctx->abort_callback;
+ cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+ return ggml_graph_compute(cgraph, &cplan);
+}
+
+GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_CPY:
+ return
+ op->type != GGML_TYPE_IQ2_XXS &&
+ op->type != GGML_TYPE_IQ2_XS &&
+ op->type != GGML_TYPE_IQ1_S &&
+ op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
+ case GGML_OP_MUL_MAT:
+ return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+ default:
+ return true;
+ }
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ return ggml_backend_buft_is_host(buft);
+
+ GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i cpu_backend_i = {
+ /* .get_name = */ ggml_backend_cpu_name,
+ /* .free = */ ggml_backend_cpu_free,
+ /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
+ /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
+ /* .graph_compute = */ ggml_backend_cpu_graph_compute,
+ /* .supports_op = */ ggml_backend_cpu_supports_op,
+ /* .supports_buft = */ ggml_backend_cpu_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+ static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+ return &guid;
+}
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+ struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ ctx->n_threads = GGML_DEFAULT_N_THREADS;
+ ctx->work_data = NULL;
+ ctx->work_size = 0;
+ ctx->abort_callback = NULL;
+ ctx->abort_callback_data = NULL;
+
+ ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
+ if (cpu_backend == NULL) {
+ free(ctx);
+ return NULL;
+ }
+
+ *cpu_backend = (struct ggml_backend) {
+ /* .guid = */ ggml_backend_cpu_guid(),
+ /* .interface = */ cpu_backend_i,
+ /* .context = */ ctx
+ };
+ return cpu_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+ ctx->n_threads = n_threads;
+}
+
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+ ctx->abort_callback = abort_callback;
+ ctx->abort_callback_data = abort_callback_data;
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+ GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
+ return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
+}
+
+GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
+ return ggml_backend_cpu_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+}
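// Illustrative sketch (not part of this file): computing an already-built and
// already-allocated graph `gf` on the CPU backend with a custom thread count.
static enum ggml_status example_cpu_compute(struct ggml_cgraph * gf) {
    ggml_backend_t cpu = ggml_backend_cpu_init();
    if (cpu == NULL) {
        return GGML_STATUS_ALLOC_FAILED;
    }
    ggml_backend_cpu_set_n_threads(cpu, 8);
    enum ggml_status st = ggml_backend_graph_compute(cpu, gf); // synchronous wrapper
    ggml_backend_free(cpu);
    return st;
}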
+
+// multi-buffer buffer
+
+struct ggml_backend_multi_buffer_context {
+ ggml_backend_buffer_t * buffers;
+ size_t n_buffers;
+};
+
+typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t;
+
+GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) {
+ ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+
+ return ctx->buffers[0]->iface.get_name(ctx->buffers[0]);
+}
+
+GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+ for (size_t i = 0; i < ctx->n_buffers; i++) {
+ ggml_backend_buffer_free(ctx->buffers[i]);
+ }
+
+ free(ctx->buffers);
+ free(ctx);
+}
+
+GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+ for (size_t i = 0; i < ctx->n_buffers; i++) {
+ ggml_backend_buffer_clear(ctx->buffers[i], value);
+ }
+}
+
+static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) {
+ static struct ggml_backend_buffer_i multi_backend_buffer_i = {
+ /* .get_name = */ ggml_backend_multi_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer,
+ /* .get_base = */ NULL,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ NULL,
+ /* .get_tensor = */ NULL,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_multi_buffer_clear,
+ /* .reset = */ NULL,
+ };
+
+ return multi_backend_buffer_i;
+}
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {
+ ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context));
+ ctx->n_buffers = n_buffers;
+ ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));
+
+ GGML_ASSERT(ctx->buffers != NULL);
+
+ size_t total_size = 0;
+ for (size_t i = 0; i < n_buffers; i++) {
+ ctx->buffers[i] = buffers[i];
+ total_size += ggml_backend_buffer_get_size(buffers[i]);
+ }
+
+ return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size);
+}
+
+GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_multi_buffer_get_name;
+}
+
+GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+ GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
+ ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context;
+ for (size_t i = 0; i < ctx->n_buffers; i++) {
+ ggml_backend_buffer_set_usage(ctx->buffers[i], usage);
+ }
+}
+
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+ struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ dup->nb[i] = tensor->nb[i];
+ }
+ return dup;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+ return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// scheduler
+
+#ifndef GGML_SCHED_MAX_BACKENDS
+#define GGML_SCHED_MAX_BACKENDS 16
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLITS
+#define GGML_SCHED_MAX_SPLITS 2048
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
+#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
+#endif
+
+#ifndef GGML_SCHED_MAX_COPIES
+#define GGML_SCHED_MAX_COPIES 4
+#endif
+
+struct ggml_backend_sched_split {
+ int backend_id;
+ int i_start;
+ int i_end;
+ struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+ int n_inputs;
+ // graph view of this split
+ struct ggml_cgraph graph;
+};
+
+struct ggml_backend_sched {
+ bool is_reset; // true if the scheduler has been reset since the last graph split
+ bool is_alloc;
+
+ int n_backends;
+
+ ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
+ ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
+ ggml_gallocr_t galloc;
+
+ // hash keys of the nodes in the graph
+ struct ggml_hash_set hash_set;
+ // hash values
+ int * tensor_backend_id;
+ struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+
+ int * node_backend_ids; // [graph_size]
+ int * leaf_backend_ids; // [graph_size]
+
+ int * prev_node_backend_ids; // [graph_size]
+ int * prev_leaf_backend_ids; // [graph_size]
+
+ // copy of the graph with modified inputs
+ struct ggml_cgraph * graph;
+
+ // graph splits
+ struct ggml_backend_sched_split * splits;
+ int n_splits;
+ int splits_capacity;
+
+ // pipeline parallelism support
+ int n_copies;
+ int cur_copy;
+ ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+ struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+ int n_graph_inputs;
+
+ struct ggml_context * ctx;
+
+ ggml_backend_sched_eval_callback callback_eval;
+ void * callback_eval_user_data;
+
+ bool debug;
+
+ // align context_buffer to GGML_MEM_ALIGN
+#ifdef _MSC_VER
+ __declspec(align(GGML_MEM_ALIGN))
+#else
+ __attribute__((aligned(GGML_MEM_ALIGN)))
+#endif
+ char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
+};
+
+#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor)
+#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)]
+
+// returns the priority of the backend, lower id is higher priority
+static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (sched->backends[i] == backend) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
+ ggml_backend_buffer_t buffer = tensor->buffer;
+ if (buffer == NULL) {
+ return -1;
+ }
+
+ // find highest prio backend that supports the buffer type and the op
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
+ ggml_backend_supports_op(sched->backends[i], op)) {
+ return i;
+ }
+ }
+
+#ifndef NDEBUG
+ fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
+ __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
+#endif
+
+ return -1;
+}
+
+#if 0
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
+// returns the backend that should be used for the node, based on where its data and inputs are currently allocated
+static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
+ // TODO: use supports_op to check if the backend supports the op
+
+ // assign pre-allocated nodes to their backend
+ int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
+ if (cur_backend_id != -1) {
+ SET_CAUSE(tensor, "1.dst");
+ return cur_backend_id;
+ }
+
+ // view_src
+ if (tensor->view_src != NULL) {
+ cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
+ if (cur_backend_id != -1) {
+ SET_CAUSE(tensor, "1.vsrc");
+ return cur_backend_id;
+ }
+ }
+
+ // graph input
+ if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
+ cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
+ SET_CAUSE(tensor, "1.inp");
+ return cur_backend_id;
+ }
+
+ // assign nodes that use weights to the backend of the weights
+ // operations with weights are preferably run on the same backend as the weights
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ const struct ggml_tensor * src = tensor->src[i];
+ if (src == NULL) {
+ continue;
+ }
+ if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+ int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
+ // check if a backend with higher prio wants to offload the op
+ if (src_backend_id == sched->n_backends - 1) {
+ for (int b = 0; b < src_backend_id; b++) {
+ if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
+ SET_CAUSE(tensor, "1.off");
+ return b;
+ }
+ }
+ }
+ SET_CAUSE(tensor, "1.wgt%d", i);
+ return src_backend_id;
+ }
+ }
+
+ return -1;
+}
+
+static char * fmt_size(size_t size) {
+ static char buffer[128];
+ if (size >= 1024*1024) {
+ snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
+ } else {
+ snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
+ }
+ return buffer;
+}
+
+static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ int cur_split = 0;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+ ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+ sched->splits[cur_split].n_inputs);
+ for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+ fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+ }
+ fprintf(stderr, "\n");
+ cur_split++;
+ }
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
+ fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
+ fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
+ fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
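+// returns true if the given backend supports the buffer type that holds (or would hold) tensor t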
+static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
+ ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+ ggml_backend_buffer_type_t buft = NULL;
+
+ if (buf) {
+ // the tensor is already allocated
+ buft = buf->buft;
+ } else {
+ // see if the tensor already has a backend assigned, and use the buffer type of that backend
+ int tensor_backend_id = tensor_backend_id(t);
+ if (tensor_backend_id == -1 && t->view_src) {
+ tensor_backend_id = tensor_backend_id(t->view_src);
+ }
+ if (tensor_backend_id != -1) {
+ buft = sched->bufts[tensor_backend_id];
+ }
+ }
+
+ return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);
+}
+
+static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+ if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+ *node_backend_id = cur_backend_id;
+ SET_CAUSE(node, "2.sup");
+ }
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ // reset splits
+ sched->n_splits = 0;
+ sched->n_graph_inputs = 0;
+ sched->is_reset = false;
+
+ struct ggml_init_params params = {
+ /* .mem_size = */ sizeof(sched->context_buffer),
+ /* .mem_buffer = */ sched->context_buffer,
+ /* .no_alloc = */ true
+ };
+
+ ggml_free(sched->ctx);
+
+ sched->ctx = ggml_init(params);
+ if (sched->ctx == NULL) {
+ fprintf(stderr, "%s: failed to initialize context\n", __func__);
+ GGML_ASSERT(false);
+ }
+
+ // pass 1: assign backends to ops with pre-allocated inputs
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ int * leaf_backend_id = &tensor_backend_id(leaf);
+ if (*leaf_backend_id != -1) {
+ // do not overwrite user assignments
+ continue;
+ }
+ *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
+ }
+
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id != -1) {
+ // do not overwrite user assignments
+ continue;
+ }
+ *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
+ // src
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ int * src_backend_id = &tensor_backend_id(src);
+ if (*src_backend_id == -1) {
+ *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
+ }
+ }
+ }
+
+ // pass 2: expand current backend assignments
+ // assign the same backend to adjacent nodes
+    // expand gpu backends (i.e. all backends except the last/lowest priority one) up and down, ignoring cpu (the lowest priority backend)
+    // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
+    // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of their inputs are known
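+    // the expansion is done in four sweeps: gpu backends down, then up, then all backends down and up, so that unassigned nodes between assigned nodes inherit a backend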
+ // expand gpu down
+ {
+ int cur_backend_id = -1;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id != -1) {
+ if (*node_backend_id == sched->n_backends - 1) {
+ // skip cpu (lowest prio backend)
+ cur_backend_id = -1;
+ } else {
+ cur_backend_id = *node_backend_id;
+ }
+ } else if (cur_backend_id != -1) {
+ ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+ }
+ }
+ }
+ // expand gpu up
+ {
+ int cur_backend_id = -1;
+ for (int i = graph->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id != -1) {
+ if (*node_backend_id == sched->n_backends - 1) {
+ // skip cpu (lowest prio backend)
+ cur_backend_id = -1;
+ } else {
+ cur_backend_id = *node_backend_id;
+ }
+ } else if (cur_backend_id != -1) {
+ ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+ }
+ }
+ }
+ // expand rest down
+ {
+ int cur_backend_id = -1;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id != -1) {
+ cur_backend_id = *node_backend_id;
+ } else if (cur_backend_id != -1) {
+ ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+ }
+ }
+ }
+ // expand rest up
+ {
+ int cur_backend_id = -1;
+ for (int i = graph->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id != -1) {
+ cur_backend_id = *node_backend_id;
+ } else if (cur_backend_id != -1) {
+ ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+ }
+ }
+ }
+
+ // pass 3: upgrade nodes to higher prio backends with compatible buffer types
+    // if the tensor is already allocated in a buffer type (*) that a higher priority backend also uses, move it to that backend
+    // however, we also need to verify that the sources are in compatible buffer types
+    // (*) the actual requirement is more relaxed: the buffer type of the backend only needs to be supported by all users of this tensor further down the graph
+    // however, this is slow to verify, so we use the stricter requirement that the buffer type is the same
+    // this is not uncommon since multiple backends can use host memory, with the same buffer type (e.g. BLAS and CPU)
+ // additionally, set remaining unassigned nodes to the backend with the most supported inputs
+ // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ int * node_backend_id = &tensor_backend_id(node);
+ if (*node_backend_id == -1) {
+ // unassigned node: find the backend with the most supported inputs
+ int n_supported_best = -1;
+ for (int b = 0; b < sched->n_backends; b++) {
+ if (ggml_backend_supports_op(sched->backends[b], node)) {
+ int n_supported = 0;
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {
+ n_supported++;
+ }
+ }
+ if (n_supported > n_supported_best) {
+ n_supported_best = n_supported;
+ *node_backend_id = b;
+ SET_CAUSE(node, "3.best");
+ }
+ }
+ }
+ } else {
+ // assigned node: upgrade to higher prio backend if possible
+ for (int b = 0; b < *node_backend_id; b++) {
+ if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {
+ bool supported = true;
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ if (!ggml_backend_sched_buffer_supported(sched, src, b)) {
+ supported = false;
+ break;
+ }
+ }
+ if (supported) {
+ *node_backend_id = b;
+ SET_CAUSE(node, "3.upg");
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ // pass 4: assign backends to remaining src from dst and view_src
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ int * cur_backend_id = &tensor_backend_id(node);
+ if (node->view_src != NULL && *cur_backend_id == -1) {
+ *cur_backend_id = tensor_backend_id(node->view_src);
+ SET_CAUSE(node, "4.vsrc");
+ }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ int * src_backend_id = &tensor_backend_id(src);
+ if (*src_backend_id == -1) {
+ if (src->view_src != NULL) {
+ // views are always on the same backend as the source
+ *src_backend_id = tensor_backend_id(src->view_src);
+ SET_CAUSE(src, "4.vsrc");
+ } else {
+ *src_backend_id = *cur_backend_id;
+ SET_CAUSE(src, "4.cur");
+ }
+ }
+ }
+ }
+
+    // pass 5: split graph, find tensors that need to be copied
+ {
+ int i_split = 0;
+ struct ggml_backend_sched_split * split = &sched->splits[0];
+ // find the backend of the first split, skipping view ops
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (!ggml_is_view_op(node->op)) {
+ split->backend_id = tensor_backend_id(node);
+ break;
+ }
+ }
+ split->i_start = 0;
+ split->n_inputs = 0;
+ memset(split->inputs, 0, sizeof(split->inputs)); //HACK
+ int cur_backend_id = split->backend_id;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+
+ const int node_backend_id = tensor_backend_id(node);
+
+ GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now
+
+ // check if we should start a new split based on the sources of the current node
+ bool need_new_split = false;
+ if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+ // check if a weight is on a different backend
+ // by starting a new split, the memory of the previously offloaded weights can be reused
+ if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+ int src_backend_id = tensor_backend_id(src);
+ if (src_backend_id != -1 && src_backend_id != cur_backend_id) {
+ need_new_split = true;
+ break;
+ }
+ }
+ // check if the split has too many inputs
+ // FIXME: count the number of inputs instead of only checking when full
+ if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
+ const size_t id = hash_id(src);
+ int src_backend_id = sched->tensor_backend_id[id];
+ bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
+ if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) {
+ //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
+ need_new_split = true;
+ break;
+ }
+ }
+ }
+ }
+
+ if (node_backend_id != cur_backend_id || need_new_split) {
+ split->i_end = i;
+ i_split++;
+ if (i_split >= sched->splits_capacity) {
+ sched->splits_capacity *= 2;
+ sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
+ GGML_ASSERT(sched->splits != NULL);
+ }
+ GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
+ split = &sched->splits[i_split];
+ split->backend_id = node_backend_id;
+ split->i_start = i;
+ split->n_inputs = 0;
+ cur_backend_id = node_backend_id;
+ }
+
+ // find inputs that are not on the same backend
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ continue;
+ }
+
+ const int src_backend_id = tensor_backend_id(src);
+ assert(src_backend_id != -1); // all inputs should be assigned by now
+
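+            // with pipeline parallelism (n_copies > 1), each graph input gets one copy per pipeline stage so that a new batch can be set while the previous one is still being processed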
+ if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
+ size_t id = hash_id(src);
+ if (sched->tensor_copies[id][src_backend_id][0] == NULL) {
+ ggml_backend_t backend = sched->backends[src_backend_id];
+ for (int c = 0; c < sched->n_copies; c++) {
+ struct ggml_tensor * tensor_copy;
+ if (c == sched->cur_copy) {
+ tensor_copy = src; // use the original tensor as the current copy
+ } else {
+ tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+ ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+ }
+ if (sched->n_copies > 1) {
+ ggml_set_input(tensor_copy);
+ ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+ }
+ sched->tensor_copies[id][src_backend_id][c] = tensor_copy;
+ SET_CAUSE(tensor_copy, "4.cpy");
+ }
+ int n_graph_inputs = sched->n_graph_inputs++;
+ GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+ sched->graph_inputs[n_graph_inputs] = src;
+ }
+ }
+
+ bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
+ if (src_backend_id != cur_backend_id && !supported) {
+ // create a copy of the input in the split's backend
+ const size_t id = hash_id(src);
+ if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
+ ggml_backend_t backend = sched->backends[cur_backend_id];
+ for (int c = 0; c < sched->n_copies; c++) {
+ struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+ ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+ if (sched->n_copies > 1) {
+ ggml_set_input(tensor_copy);
+ ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+ }
+ sched->tensor_copies[id][cur_backend_id][c] = tensor_copy;
+ SET_CAUSE(tensor_copy, "4.cpy");
+ }
+ int n_inputs = split->n_inputs++;
+ GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+ split->inputs[n_inputs] = src;
+ }
+ node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy];
+ }
+ }
+ }
+ split->i_end = graph->n_nodes;
+ sched->n_splits = i_split + 1;
+ }
+
+ if (sched->debug) {
+ ggml_backend_sched_print_assignments(sched, graph);
+ }
+
+    // swap node_backend_ids and leaf_backend_ids with their previous versions
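+    // the previous assignments are compared in ggml_backend_sched_alloc_splits to decide whether the graph allocation needs to be redone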
+ {
+ int * tmp = sched->node_backend_ids;
+ sched->node_backend_ids = sched->prev_node_backend_ids;
+ sched->prev_node_backend_ids = tmp;
+
+ tmp = sched->leaf_backend_ids;
+ sched->leaf_backend_ids = sched->prev_leaf_backend_ids;
+ sched->prev_leaf_backend_ids = tmp;
+ }
+
+ // create copies of the graph for each split
+ // TODO: avoid this copy
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false);
+ for (int i = 0; i < sched->n_splits; i++) {
+ struct ggml_backend_sched_split * split = &sched->splits[i];
+ split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
+
+ // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+ for (int j = 0; j < split->n_inputs; j++) {
+ assert(graph_copy->size > (graph_copy->n_nodes + 1));
+
+ struct ggml_tensor * input = split->inputs[j];
+ const size_t input_id = hash_id(input);
+ struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy];
+
+ // add a dependency to the input source so that it is not freed before the copy is done
+ struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
+ input_dep->src[0] = input;
+ sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id];
+ graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
+
+ // add a dependency to the input copy so that it is allocated at the start of the split
+ sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;
+ graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+ }
+
+ for (int j = split->i_start; j < split->i_end; j++) {
+ assert(graph_copy->size > graph_copy->n_nodes);
+ sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
+ graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+ }
+ }
+
+ if (sched->n_copies > 1) {
+ // add input copies as leafs so that they are allocated first
+ for (int i = 0; i < sched->n_graph_inputs; i++) {
+ struct ggml_tensor * input = sched->graph_inputs[i];
+ size_t id = hash_id(input);
+ int backend_id = tensor_backend_id(input);
+ for (int c = 0; c < sched->n_copies; c++) {
+ struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+ sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+ graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+ }
+ }
+
+ for (int i = 0; i < sched->n_splits; i++) {
+ struct ggml_backend_sched_split * split = &sched->splits[i];
+ int backend_id = split->backend_id;
+ for (int j = 0; j < split->n_inputs; j++) {
+ struct ggml_tensor * input = split->inputs[j];
+ size_t id = hash_id(input);
+ for (int c = 0; c < sched->n_copies; c++) {
+ struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c];
+ sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+ graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+ }
+ }
+ }
+ }
+
+ // add leafs from the original graph
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
+ graph_copy->leafs[graph_copy->n_leafs++] = leaf;
+ }
+
+ sched->graph = graph_copy;
+}
+
+static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
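+    // a change of backend assignments only requires reallocating the graph if a tensor moved to a backend that uses a different buffer type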
+ bool backend_ids_changed = false;
+ for (int i = 0; i < sched->graph->n_nodes; i++) {
+ if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+ sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
+ backend_ids_changed = true;
+ break;
+ }
+ }
+ if (!backend_ids_changed) {
+ for (int i = 0; i < sched->graph->n_leafs; i++) {
+ if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+ sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
+ backend_ids_changed = true;
+ break;
+ }
+ }
+ }
+
+ // allocate graph
+ if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+ // the re-allocation may cause the split inputs to be moved to a different address
+ ggml_backend_sched_synchronize(sched);
+#ifndef NDEBUG
+ fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__);
+#endif
+ ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
+ if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+ fprintf(stderr, "%s: failed to allocate graph\n", __func__);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+ struct ggml_backend_sched_split * splits = sched->splits;
+
+ for (int i = 0; i < sched->n_splits; i++) {
+ struct ggml_backend_sched_split * split = &splits[i];
+ int split_backend_id = split->backend_id;
+ ggml_backend_t split_backend = sched->backends[split_backend_id];
+
+ // copy the input tensors to the split backend
+ for (int j = 0; j < split->n_inputs; j++) {
+ ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
+ struct ggml_tensor * input = split->inputs[j];
+ struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy];
+
+ if (input->flags & GGML_TENSOR_FLAG_INPUT) {
+ // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
+ if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+ ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+ } else {
+ ggml_backend_synchronize(split_backend);
+ }
+ ggml_backend_tensor_copy(input, input_cpy);
+ } else {
+ // wait for the split backend to finish using the input before overwriting it
+ if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+ ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
+ } else {
+ ggml_backend_synchronize(split_backend);
+ }
+ ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
+ }
+ }
+
+ if (!sched->callback_eval) {
+ enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
+ if (ec != GGML_STATUS_SUCCESS) {
+ return ec;
+ }
+ } else {
+ // similar to ggml_backend_compare_graph_backend
+ for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+ struct ggml_tensor * t = split->graph.nodes[j0];
+
+ // check if the user needs data from this node
+ bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+ int j1 = j0;
+
+ // determine the range [j0, j1] of nodes that can be computed together
+ while (!need && j1 < split->graph.n_nodes - 1) {
+ t = split->graph.nodes[++j1];
+ need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+ }
+
+ struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+ enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
+ if (ec != GGML_STATUS_SUCCESS) {
+ return ec;
+ }
+
+ // TODO: pass backend to the callback, then the user can decide if they want to synchronize
+ ggml_backend_synchronize(split_backend);
+
+ if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+ break;
+ }
+
+ j0 = j1;
+ }
+ }
+
+ // record the event of this copy
+ if (split->n_inputs > 0) {
+ if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+ ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]);
+ }
+ }
+ }
+
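+    // advance to the next copy set for the next graph evaluation (no-op when n_copies == 1)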
+ sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
+
+ return GGML_STATUS_SUCCESS;
+}
+
+ggml_backend_sched_t ggml_backend_sched_new(
+ ggml_backend_t * backends,
+ ggml_backend_buffer_type_t * bufts,
+ int n_backends,
+ size_t graph_size,
+ bool parallel) {
+ GGML_ASSERT(n_backends > 0);
+ GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
+ GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
+
+ struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
+
+ sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
+
+ // initialize hash table
+ sched->hash_set = ggml_hash_set_new(graph_size);
+ sched->tensor_backend_id = calloc(sched->hash_set.size, sizeof(sched->tensor_backend_id[0]));
+ sched->tensor_copies = calloc(sched->hash_set.size, sizeof(sched->tensor_copies[0]));
+
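+    // reserve extra space for the input copies and dependency views that ggml_backend_sched_split_graph appends to the graph copy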
+ const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+ sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
+ sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
+ sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
+ sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
+
+ sched->n_backends = n_backends;
+
+ sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+
+ const int initial_splits_capacity = 16;
+ sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0]));
+ sched->splits_capacity = initial_splits_capacity;
+
+ for (int b = 0; b < n_backends; b++) {
+ sched->backends[b] = backends[b];
+ sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
+ GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
+ if (sched->n_copies > 1) {
+ for (int c = 0; c < sched->n_copies; c++) {
+ sched->events[b][c] = ggml_backend_event_new(backends[b]);
+ }
+ }
+ }
+
+ sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+
+ ggml_backend_sched_reset(sched);
+
+ return sched;
+}
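+
+// Minimal usage sketch (illustrative only; build_graph() and the backend selection below are placeholders,
+// and ggml_backend_cuda_init() is only available when ggml is built with CUDA support):
+//
+//   ggml_backend_t backends[2] = { ggml_backend_cuda_init(0), ggml_backend_cpu_init() }; // last backend must be CPU
+//   ggml_backend_sched_t sched = ggml_backend_sched_new(backends, NULL, 2, GGML_DEFAULT_GRAPH_SIZE, false);
+//   struct ggml_cgraph * graph = build_graph();     // user-provided graph builder (hypothetical)
+//   ggml_backend_sched_reserve(sched, graph);       // optional: pre-allocate using a worst-case graph
+//   ggml_backend_sched_graph_compute(sched, graph);
+//   ggml_backend_sched_reset(sched);
+//   ggml_backend_sched_free(sched);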
+
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+ if (sched == NULL) {
+ return;
+ }
+ for (int b = 0; b < sched->n_backends; b++) {
+ for (int c = 0; c < sched->n_copies; c++) {
+ ggml_backend_event_free(sched->events[b][c]);
+ }
+ }
+ ggml_gallocr_free(sched->galloc);
+ ggml_free(sched->ctx);
+ free(sched->splits);
+ free(sched->hash_set.keys);
+ free(sched->tensor_backend_id);
+ free(sched->tensor_copies);
+ free(sched->node_backend_ids);
+ free(sched->leaf_backend_ids);
+ free(sched->prev_node_backend_ids);
+ free(sched->prev_leaf_backend_ids);
+ free(sched);
+}
+
+void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+ // reset state for the next run
+ if (!sched->is_reset) {
+ size_t hash_size = sched->hash_set.size;
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT
+ memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size);
+ memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size);
+
+ sched->is_reset = true;
+ }
+ sched->is_alloc = false;
+}
+
+bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+ GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes);
+
+ ggml_backend_sched_split_graph(sched, measure_graph);
+
+ // TODO: extract this to a separate function
+ if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
+ return false;
+ }
+
+ ggml_backend_sched_reset(sched);
+ ggml_backend_sched_synchronize(sched);
+
+ return true;
+}
+
+bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes);
+
+ ggml_backend_sched_split_graph(sched, graph);
+
+ if (!ggml_backend_sched_alloc_splits(sched)) {
+ return false;
+ }
+
+ sched->is_alloc = true;
+
+ return true;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
+ ggml_backend_sched_synchronize(sched);
+ return err;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ if (!sched->is_reset && !sched->is_alloc) {
+ ggml_backend_sched_reset(sched);
+ }
+
+ if (!sched->is_alloc) {
+ if (!ggml_backend_sched_alloc_graph(sched, graph)) {
+ return GGML_STATUS_ALLOC_FAILED;
+ }
+ }
+
+ return ggml_backend_sched_compute_splits(sched);
+}
+
+void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+ for (int i = 0; i < sched->n_backends; i++) {
+ ggml_backend_synchronize(sched->backends[i]);
+ }
+}
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+ sched->callback_eval = callback;
+ sched->callback_eval_user_data = user_data;
+}
+
+int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+ return sched->n_splits;
+}
+
+int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+ return sched->n_copies;
+}
+
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+ return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+ GGML_ASSERT(i >= 0 && i < sched->n_backends);
+ return sched->backends[i];
+}
+
+size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+
+ return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
+}
+
+void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+ tensor_backend_id(node) = backend_index;
+ SET_CAUSE(node, "usr");
+}
+
+ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+ int backend_index = tensor_backend_id(node);
+ if (backend_index == -1) {
+ return NULL;
+ }
+ return sched->backends[backend_index];
+}
+
+// utils
+
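+// initialize a view tensor: the view shares its source's buffer and points into the source's data at view_offs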
+void ggml_backend_view_init(struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->buffer == NULL);
+ GGML_ASSERT(tensor->view_src != NULL);
+ GGML_ASSERT(tensor->view_src->buffer != NULL);
+ GGML_ASSERT(tensor->view_src->data != NULL);
+
+ tensor->buffer = tensor->view_src->buffer;
+ tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+ ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+ GGML_ASSERT(tensor->buffer == NULL);
+ GGML_ASSERT(tensor->data == NULL);
+ GGML_ASSERT(tensor->view_src == NULL);
+ GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+ GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+ (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+ tensor->buffer = buffer;
+ tensor->data = addr;
+ ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
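+// recursively duplicate a tensor and its sources for the graph copy: tensors that own data go to ctx_allocated, views and unallocated tensors go to ctx_unallocated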
+static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+ struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+ GGML_ASSERT(src != NULL);
+ GGML_ASSERT(src->data && "graph must be allocated");
+
+ size_t id = ggml_hash_insert(hash_set, src);
+ if (id == GGML_HASHTABLE_ALREADY_EXISTS) {
+ return node_copies[ggml_hash_find(hash_set, src)];
+ }
+
+ struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+ if (src->view_src != NULL) {
+ dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+ dst->view_offs = src->view_offs;
+ }
+ dst->op = src->op;
+ memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+ ggml_set_name(dst, src->name);
+
+ // copy src
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ struct ggml_tensor * s = src->src[i];
+ if (s == NULL) {
+ continue;
+ }
+ dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+ }
+
+ node_copies[id] = dst;
+ return dst;
+}
+
+static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+ size_t id = ggml_hash_find(hash_set, src);
+ if (node_init[id]) {
+ return;
+ }
+ node_init[id] = true;
+
+ struct ggml_tensor * dst = node_copies[id];
+ if (dst->view_src != NULL) {
+ graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
+ ggml_backend_view_init(dst);
+ }
+ else {
+ ggml_backend_tensor_copy(src, dst);
+ }
+
+ // init src
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ struct ggml_tensor * s = src->src[i];
+ if (s == NULL) {
+ continue;
+ }
+ graph_copy_init_tensor(hash_set, node_copies, node_init, s);
+ }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+ struct ggml_hash_set hash_set = {
+ /* .size = */ graph->visited_hash_table.size,
+ /* .keys = */ calloc(graph->visited_hash_table.size, sizeof(hash_set.keys[0])) // NOLINT
+ };
+ struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
+ bool * node_init = calloc(hash_set.size, sizeof(node_init[0]));
+
+ struct ggml_init_params params = {
+ /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+ /* .mem_buffer = */ NULL,
+ /* .no_alloc = */ true
+ };
+
+ struct ggml_context * ctx_allocated = ggml_init(params);
+ struct ggml_context * ctx_unallocated = ggml_init(params);
+
+ if (ctx_allocated == NULL || ctx_unallocated == NULL) {
+ fprintf(stderr, "failed to allocate context for graph copy\n");
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+ ggml_free(ctx_allocated);
+ ggml_free(ctx_unallocated);
+ return (struct ggml_backend_graph_copy) {
+ /* .buffer = */ NULL,
+ /* .ctx_allocated = */ NULL,
+ /* .ctx_unallocated = */ NULL,
+ /* .graph = */ NULL,
+ };
+ }
+
+ // dup nodes
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+ }
+
+ // allocate nodes
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+ if (buffer == NULL) {
+ fprintf(stderr, "failed to allocate buffer for graph copy\n");
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+ ggml_free(ctx_allocated);
+ ggml_free(ctx_unallocated);
+ return (struct ggml_backend_graph_copy) {
+ /* .buffer = */ NULL,
+ /* .ctx_allocated = */ NULL,
+ /* .ctx_unallocated = */ NULL,
+ /* .graph = */ NULL,
+ };
+ }
+
+ //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+ // copy data and init views
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ graph_copy_init_tensor(hash_set, node_copies, node_init, node);
+ }
+
+ // build graph copy
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)];
+ graph_copy->nodes[i] = node_copy;
+ }
+ graph_copy->n_nodes = graph->n_nodes;
+
+ free(hash_set.keys);
+ free(node_copies);
+ free(node_init);
+
+ return (struct ggml_backend_graph_copy) {
+ /* .buffer = */ buffer,
+ /* .ctx_allocated = */ ctx_allocated,
+ /* .ctx_unallocated = */ ctx_unallocated,
+ /* .graph = */ graph_copy,
+ };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+ ggml_backend_buffer_free(copy.buffer);
+ ggml_free(copy.ctx_allocated);
+ ggml_free(copy.ctx_unallocated);
+}
+
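+// computes the graph node by node on both backends and calls the callback after each node so that the results can be compared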
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+ struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+ if (copy.buffer == NULL) {
+ return false;
+ }
+
+ struct ggml_cgraph * g1 = graph;
+ struct ggml_cgraph * g2 = copy.graph;
+
+ assert(g1->n_nodes == g2->n_nodes);
+
+ for (int i = 0; i < g1->n_nodes; i++) {
+ //printf("eval %d/%d\n", i, g1->n_nodes);
+ struct ggml_tensor * t1 = g1->nodes[i];
+ struct ggml_tensor * t2 = g2->nodes[i];
+
+ assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+ struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+ struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+ ggml_backend_graph_compute(backend1, &g1v);
+ ggml_backend_graph_compute(backend2, &g2v);
+
+ if (ggml_is_view_op(t1->op)) {
+ continue;
+ }
+
+ // compare results, calculate rms etc
+ if (!callback(i, t1, t2, user_data)) {
+ break;
+ }
+ }
+
+ ggml_backend_graph_copy_free(copy);
+
+ return true;
+}
diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas.cpp
new file mode 100644
index 00000000..a37aa407
--- /dev/null
+++ b/ggml/src/ggml-blas.cpp
@@ -0,0 +1,368 @@
+#include "ggml-blas.h"
+#include "ggml-backend-impl.h"
+
+#include <future>
+#include <vector>
+
+#if defined(GGML_USE_ACCELERATE)
+# include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+# include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+# include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+# include <nvpl_blas.h>
+#else
+# include <cblas.h>
+#endif
+
+struct ggml_backend_blas_context {
+ int n_threads = GGML_DEFAULT_N_THREADS;
+ std::unique_ptr<char[]> work_data;
+ size_t work_size = 0;
+#ifndef GGML_USE_OPENMP
+ std::vector<std::future<void>> tasks;
+#endif
+};
+
+// helper function to determine if it is better to use BLAS or not
+// for large matrices, BLAS is faster
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ const int64_t ne10 = src1->ne[0];
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+
+ // TODO: find the optimal values for these
+ if (ggml_is_contiguous(src0) &&
+ ggml_is_contiguous(src1) &&
+ src1->type == GGML_TYPE_F32 &&
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+ /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
+ return true;
+ }
+
+ return false;
+}
+
+static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const enum ggml_type type = src0->type;
+
+ GGML_ASSERT(ne0 == ne01);
+ GGML_ASSERT(ne1 == ne11);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+ const int64_t ne_plane = ne01*ne00;
+ const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+
+ if (ctx->work_size < desired_wsize) {
+ ctx->work_data.reset(new char[desired_wsize]);
+ ctx->work_size = desired_wsize;
+ }
+ void * wdata = ctx->work_data.get();
+
+ // convert src0 to float
+ if (type != GGML_TYPE_F32) {
+ ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
+ ggml_to_float_t const to_float = type_traits.to_float;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+ float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+
+ const int min_cols_per_thread = 4096;
+ const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+ const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
+
+#ifdef GGML_USE_OPENMP
+ #pragma omp parallel for num_threads(n_threads)
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+ }
+#else
+ for (int i = 1; i < n_threads; i++) {
+ const int64_t start = i*ne01/n_threads;
+ const int64_t end = (i + 1)*ne01/n_threads;
+ if (start < end) {
+ ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+ for (int64_t i01 = start; i01 < end; i01++) {
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+ }
+ }));
+ }
+ }
+ {
+ // reuse the current thread for the first task
+ const int64_t start = 0;
+ const int64_t end = ne01/n_threads;
+ for (int64_t i01 = start; i01 < end; i01++) {
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+ }
+ }
+#endif
+ }
+ }
+
+#ifndef GGML_USE_OPENMP
+ // wait for all tasks to finish
+ for (auto & task : ctx->tasks) {
+ task.get();
+ }
+ ctx->tasks.clear();
+#endif
+ }
+
+#if defined(OPENBLAS_VERSION)
+ openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_BLIS)
+ bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_NVPL)
+ nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
+ const int64_t i03 = i13/r3;
+ const int64_t i02 = i12/r2;
+
+ const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+ const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+
+ if (type != GGML_TYPE_F32) {
+ x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+ }
+
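+            // d = y * x^T : (ne1 x ne01) = (ne1 x ne10) * (ne10 x ne01), i.e. ggml_mul_mat computes dst = src1 * src0^T in row-major layout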
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ ne1, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne00,
+ 0.0f, d, ne01);
+ }
+ }
+}
+
+static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+ GGML_ASSERT(ne03 == ne13);
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ // GGML_ASSERT(nb0 <= nb1);
+ // GGML_ASSERT(nb1 <= nb2);
+ // GGML_ASSERT(nb2 <= nb3);
+
+ // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+ // src0: (k,n)
+ // src1: (k,m)
+ // dst: (m,n)
+ //
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+ // Also expressed as (major,minor)
+ // a: (m,k): so src1 transposed
+ // b: (k,n): so src0
+ // c: (m,n)
+ //
+ // However, if ggml_is_transposed(src1) is true, then
+ // src1->data already contains a transposed version, so sgemm mustn't
+ // transpose it further.
+
+ int n = src0->ne[0];
+ int k = src0->ne[1];
+ int m = src1->ne[0];
+
+ CBLAS_TRANSPOSE transposeA;
+ int lda;
+
+ if (!ggml_is_transposed(src1)) {
+ transposeA = CblasTrans;
+ lda = m;
+ } else {
+ transposeA = CblasNoTrans;
+ lda = k;
+ }
+
+ float * a = (float *) ((char *) src1->data);
+ float * b = (float *) ((char *) src0->data);
+ float * c = (float *) ((char *) dst->data);
+
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+ GGML_UNUSED(ctx);
+}
+
+// backend interface
+
+GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+ return "BLAS";
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+ delete ctx;
+ delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_cpu_buffer_type();
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ switch (node->op) {
+ case GGML_OP_MUL_MAT:
+ ggml_backend_blas_mul_mat(ctx, node);
+ break;
+
+ case GGML_OP_OUT_PROD:
+ ggml_backend_blas_out_prod(ctx, node);
+ break;
+
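+            // view and no-op operations require no computation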
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ break;
+
+ default:
+ fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+ GGML_ASSERT(false);
+ }
+ }
+
+ return GGML_STATUS_SUCCESS;
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+
+ return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
+ (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ ggml_is_matrix(src0) &&
+ ggml_is_matrix(src1) &&
+ ggml_is_contiguous(src0) &&
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ return ggml_backend_buft_is_host(buft);
+
+ GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i blas_backend_i = {
+ /* .get_name = */ ggml_backend_blas_name,
+ /* .free = */ ggml_backend_blas_free,
+ /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_blas_graph_compute,
+ /* .supports_op = */ ggml_backend_blas_supports_op,
+ /* .supports_buft = */ ggml_backend_blas_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_blas_guid(void) {
+ static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+ return &guid;
+}
+
+ggml_backend_t ggml_backend_blas_init(void) {
+ ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
+
+ ggml_backend_t backend = new ggml_backend {
+ /* .guid = */ ggml_backend_blas_guid(),
+ /* .interface = */ blas_backend_i,
+ /* .context = */ ctx,
+ };
+
+#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+ if (openblas_get_parallel() != OPENBLAS_OPENMP) {
+ fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+ }
+#endif
+
+#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+ fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#endif
+
+ return backend;
+}
+
+GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
+}
+
+void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
+ GGML_ASSERT(ggml_backend_is_blas(backend_blas));
+
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
+ ctx->n_threads = n_threads;
+}
diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp
new file mode 100644
index 00000000..9bf7e332
--- /dev/null
+++ b/ggml/src/ggml-cann.cpp
@@ -0,0 +1,2023 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "ggml-cann.h"
+
+#include <acl/acl.h>
+#include <stdarg.h>
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <mutex>
+
+#include "ggml-backend-impl.h"
+#include "ggml-cann/aclnn_ops.h"
+#include "ggml-cann/common.h"
+
+#define GGML_COMMON_DECL_C
+
+#include "ggml-common.h"
+
+/**
+ * @brief Default logging callback for GGML.
+ *
+ * This function is the default logging callback that logs messages to stderr.
+ *
+ * @param level The log level.
+ * @param msg The log message.
+ * @param user_data User data passed to the callback.
+ */
+static void ggml_cann_default_log_callback(enum ggml_log_level level,
+ const char* msg, void* user_data) {
+ GGML_UNUSED(level);
+ GGML_UNUSED(user_data);
+ fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback;
+void* ggml_cann_log_user_data = NULL;
+
+GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
+ void* user_data) {
+ ggml_cann_log_callback = log_callback;
+ ggml_cann_log_user_data = user_data;
+}
+
+#define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CANN_LOG_ERROR(...) \
+ ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+
+/**
+ * @brief Log a message using the current logging callback.
+ *
+ * This function formats a log message and passes it to the current logging
+ * callback.
+ *
+ * @param level The log level.
+ * @param format The format string for the log message.
+ * @param ... The arguments for the format string.
+ */
+static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) {
+ if (ggml_cann_log_callback != NULL) {
+ va_list args;
+ va_start(args, format);
+ char buffer[128];
+ int len = vsnprintf(buffer, 128, format, args);
+ if (len < 128) {
+ ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data);
+ } else {
+            // vsnprintf returned the required length excluding the null terminator, so allocate len + 1 bytes and format again
+ std::vector<char> buffer2(len + 1);
+ va_end(args);
+ va_start(args, format);
+ vsnprintf(&buffer2[0], buffer2.size(), format, args);
+ ggml_cann_log_callback(level, buffer2.data(),
+ ggml_cann_log_user_data);
+ }
+ va_end(args);
+ }
+}
+
+/**
+ * @brief Handles CANN errors by printing an error message and aborting.
+ *
+ * @param stmt The statement that caused the error.
+ * @param func The function in which the error occurred.
+ * @param file The file in which the error occurred.
+ * @param line The line number where the error occurred.
+ * @param msg The error message.
+ */
+[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
+ const char* file, int line, const char* msg) {
+ int32_t id = -1;
+ aclrtGetDevice(&id);
+
+ GGML_CANN_LOG_ERROR("CANN error: %s\n", msg);
+ GGML_CANN_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
+ file, line);
+ GGML_CANN_LOG_ERROR(" %s\n", stmt);
+ // abort with GGML_ASSERT to get a stack trace
+ GGML_ASSERT(!"CANN error");
+}
+
+/**
+ * @brief Sets the device to be used by CANN.
+ *
+ * @param device The device ID to set.
+ */
+void ggml_cann_set_device(const int32_t device) {
+    // TODO: uncomment these lines after the empty-context issue has been fixed.
+ // int current_device;
+ // ACL_CHECK(aclrtGetDevice(&current_device));
+
+ // if (device == current_device) {
+ // return;
+ // }
+ ACL_CHECK(aclrtSetDevice(device));
+}
+
+/**
+ * @brief Retrieves the current device ID.
+ *
+ * @return The current device ID.
+ */
+int32_t ggml_cann_get_device() {
+ int32_t id;
+ ACL_CHECK(aclrtGetDevice(&id));
+ return id;
+}
+
+/**
+ * @brief Initialize the CANN device information.
+ *
+ * This function initializes the CANN device information by obtaining the
+ * device count and setting the memory allocation granularity for each device.
+ *
+ * @return A structure containing the device information.
+ */
+static ggml_cann_device_info ggml_cann_init() {
+ ggml_cann_device_info info = {};
+
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
+
+ if (err != ACL_SUCCESS) {
+ GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n",
+ __func__, aclGetRecentErrMsg());
+ return info;
+ }
+
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
+
+ for (int id = 0; id < info.device_count; ++id) {
+ aclrtPhysicalMemProp prop = {};
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
+ prop.memAttr = ACL_HBM_MEM_HUGE;
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
+ prop.location.id = id;
+ prop.reserve = 0;
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
+ &info.devices[id].vmm_granularity));
+ }
+
+ // TODO: add more device info later.
+ return info;
+}
+
+/**
+ * @brief Retrieve the CANN device information.
+ *
+ * This function returns a reference to a structure containing the CANN device
+ * information. The device information is initialized once and reused on
+ * subsequent calls.
+ *
+ * @return A reference to the structure containing the device information.
+ */
+const ggml_cann_device_info& ggml_cann_info() {
+ static ggml_cann_device_info info = ggml_cann_init();
+ return info;
+}
+
+//#define DEBUG_CANN_MALLOC
+/**
+ * @brief A pool of CANN buffers (legacy).
+ *
+ * This class manages a pool of CANN buffers for a specific device.
+ */
+struct ggml_cann_pool_leg : public ggml_cann_pool {
+ /**
+ * @brief The maximum number of buffers in the pool.
+ */
+ static const int MAX_BUFFERS = 256;
+
+ /**
+ * @brief The device ID associated with this buffer pool.
+ */
+ int device;
+
+ /**
+ * @brief Structure representing a CANN buffer.
+ */
+ struct ggml_cann_buffer {
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
+ size_t size = 0; ///< Size of the buffer.
+ };
+
+ /**
+ * @brief Array of CANN buffers in the pool.
+ */
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
+
+ /**
+ * @brief Total size of all buffers in the pool.
+ */
+ size_t pool_size = 0;
+
+ /**
+ * @brief Constructor to initialize the buffer pool for a specific device.
+ *
+ * @param device The device ID to associate with this buffer pool.
+ */
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
+
+ /**
+ * @brief Destructor to free all buffers in the pool.
+ */
+ ~ggml_cann_pool_leg() {
+ ggml_cann_set_device(device);
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cann_buffer& b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+ ACL_CHECK(aclrtFree(b.ptr));
+ pool_size -= b.size;
+ }
+ }
+ GGML_ASSERT(pool_size == 0);
+ }
+
+ /**
+ * @brief Allocate a buffer of the given size.
+ *
+ * @param size The size of the buffer to allocate.
+ * @param actual_size A pointer to a variable to receive the actual size of
+ * the allocated buffer.
+ * @return A pointer to the allocated buffer.
+ */
+ void* alloc(size_t size, size_t* actual_size) override {
+#ifdef DEBUG_CANN_MALLOC
+ int nnz = 0;
+ size_t max_size = 0;
+#endif
+ size_t best_diff = 1ull << 36;
+ int ibest = -1;
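+        // best-fit search over the pooled buffers; an exact size match is returned immediately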
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cann_buffer& b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+#ifdef DEBUG_CANN_MALLOC
+ ++nnz;
+ if (b.size > max_size) max_size = b.size;
+#endif
+ if (b.size >= size) {
+ size_t diff = b.size - size;
+ if (diff < best_diff) {
+ best_diff = diff;
+ ibest = i;
+ if (!best_diff) {
+ void* ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ }
+ }
+ }
+ }
+ if (ibest >= 0) {
+ ggml_cann_buffer& b = buffer_pool[ibest];
+ void* ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
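+        // no suitable pooled buffer: allocate a new one, padded by 5% and rounded up to a multiple of 256 bytes to improve the chance of reuse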
+ void* ptr;
+ size_t look_ahead_size = (size_t)(1.05 * size);
+ look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
+ ggml_cann_set_device(device);
+ ACL_CHECK(
+ aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
+ *actual_size = look_ahead_size;
+ pool_size += look_ahead_size;
+#ifdef DEBUG_CANN_MALLOC
+ GGML_CANN_LOG_INFO(
+ "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
+ "requested %u MB\n",
+ __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
+ (uint32_t)(pool_size / 1024 / 1024),
+ (uint32_t)(size / 1024 / 1024));
+#endif
+ return ptr;
+ }
+
+ /**
+ * @brief Free a buffer and return it to the pool.
+ *
+ * @param ptr Pointer to the buffer to free.
+ * @param size Size of the buffer to free.
+ */
+ void free(void* ptr, size_t size) override {
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cann_buffer& b = buffer_pool[i];
+ if (b.ptr == nullptr) {
+ b.ptr = ptr;
+ b.size = size;
+ return;
+ }
+ }
+        // the memory should always remain buffered: it may still be needed by
+        // tasks in the stream.
+        // TODO: fix me.
+        GGML_ASSERT(!"CANN buffer pool full, increase MAX_BUFFERS\n");
+ }
+};
+
+/**
+ * @brief A pool of CANN buffers with virtual memory.
+ *
+ * This class manages a pool of CANN buffers with virtual memory for a specific
+ * device.
+ */
+struct ggml_cann_pool_vmm : public ggml_cann_pool {
+ /**
+ * @brief The maximum size of the virtual memory pool (32 GB).
+ */
+ static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
+
+ /**
+ * @brief The device ID associated with this buffer pool.
+ */
+ int device;
+
+ /**
+ * @brief Pointer to the start of the virtual memory pool.
+ */
+ void* pool_addr = 0;
+
+ /**
+ * @brief Amount of virtual memory used in the pool.
+ */
+ size_t pool_used = 0;
+
+ /**
+ * @brief Total size of the virtual memory pool.
+ */
+ size_t pool_size = 0;
+
+ /**
+ * @brief Allocation granularity for the virtual memory pool.
+ */
+ size_t granularity;
+
+ /**
+ * @brief Handles for the physical memory allocated.
+ */
+ std::vector<aclrtDrvMemHandle> handles;
+
+ /**
+ * @brief Offsets for the mapped memory regions.
+ */
+ std::vector<void*> map_offsets;
+
+ /**
+ * @brief Constructor to initialize the buffer pool with virtual memory for
+ * a specific device.
+ *
+ * @param device The device ID to associate with this buffer pool.
+ */
+ explicit ggml_cann_pool_vmm(int device)
+ : device(device),
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {}
+
+ /**
+ * @brief Destructor to free all buffers in the virtual memory pool.
+ */
+ ~ggml_cann_pool_vmm() {
+ if (pool_addr != 0) {
+ for (auto& offset : map_offsets) {
+ ACL_CHECK(aclrtUnmapMem(offset));
+ }
+ for (auto& handle : handles) {
+ ACL_CHECK(aclrtFreePhysical(handle));
+ }
+ ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
+ }
+ }
+
+ /**
+ * @brief Allocate a buffer of the given size in the virtual memory pool.
+ *
+ * @param size The size of the buffer to allocate.
+ * @param actual_size A pointer to a variable to receive the actual size of
+ * the allocated buffer.
+ * @return A pointer to the allocated buffer.
+ */
+ void* alloc(size_t size, size_t* actual_size) override {
+ // round up the allocation size to the alignment to ensure that all
+ // allocations are aligned for all data types
+ const size_t alignment = 128;
+ size = alignment * ((size + alignment - 1) / alignment);
+
+ size_t avail = pool_size - pool_used;
+
+ if (size > avail) {
+ // round up to the next multiple of the granularity
+ size_t reserve_size = size - avail;
+ reserve_size =
+ granularity * ((reserve_size + granularity - 1) / granularity);
+
+ GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
+
+ // allocate more physical memory
+ aclrtPhysicalMemProp prop = {};
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
+ prop.memAttr = ACL_HBM_MEM_HUGE;
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
+ prop.location.id = device;
+ prop.reserve = 0;
+ aclrtDrvMemHandle handle;
+ ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
+
+ // reserve virtual address space (if not already reserved)
+ if (pool_addr == 0) {
+ ACL_CHECK(aclrtReserveMemAddress(
+ &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
+ }
+
+ // map at the end of the pool
+ ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
+ handle, 0));
+
+ handles.push_back(handle);
+ map_offsets.push_back((char*)pool_addr + pool_size);
+
+ // add to the pool
+ pool_size += reserve_size;
+
+ // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB (
+ // reserved %llu MB)\n",
+ // device, (unsigned long long) (pool_size/1024/1024),
+ // (unsigned long long) (reserve_size/1024/1024));
+ }
+
+ GGML_ASSERT(pool_addr != 0);
+
+ void* ptr = (void*)((char*)pool_addr + pool_used);
+ *actual_size = size;
+ pool_used += size;
+
+#ifdef DEBUG_CANN_MALLOC
+ GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
+ (unsigned long long)size, (unsigned long long)ptr);
+#endif
+ return ptr;
+ }
+
+ /**
+ * @brief Free a buffer and return it to the virtual memory pool.
+ *
+ * @param ptr Pointer to the buffer to free.
+ * @param size Size of the buffer to free.
+ */
+ void free(void* ptr, size_t size) override {
+#ifdef DEBUG_CANN_MALLOC
+ GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
+ (unsigned long long)size, (unsigned long long)ptr);
+#endif
+
+ pool_used -= size;
+
+ // all deallocations must be in reverse order of the allocations
+ GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
+ }
+};
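+
+// Illustrative growth example for the VMM pool above (the 2 MiB granularity is
+// only an assumption for the sake of the numbers): with an empty pool, a
+// request of 5,000,000 bytes is aligned up to 5,000,064 bytes (a multiple of
+// 128); since nothing is available, reserve_size is rounded up to the next
+// multiple of the granularity, 3 * 2 MiB = 6,291,456 bytes, which is backed by
+// one aclrtMallocPhysical handle and mapped at the current end of the pool's
+// reserved virtual address range.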
+
+/**
+ * @brief Create a new CANN pool for a specific device.
+ *
+ * Factory method to create a new CANN pool object based on the device type.
+ *
+ * @param device The device ID for which to create the pool.
+ * @return A unique pointer to the created CANN pool.
+ */
+std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
+ int device) {
+ // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
+}
+
+// cann buffer
+/**
+ * @brief Context for managing a CANN buffer associated with a specific device.
+ *
+ * This structure holds information about a CANN buffer, including the device
+ * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
+ */
+struct ggml_backend_cann_buffer_context {
+ int32_t device; ///< The device ID associated with this buffer context.
+ void* dev_ptr =
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
+
+ /**
+ * @brief Constructor to initialize the CANN buffer context.
+ *
+ * @param device The device ID associated with this buffer context.
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
+ */
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
+ : device(device),
+ dev_ptr(dev_ptr) {}
+
+ /**
+ * @brief Destructor to free the device memory allocated for the buffer.
+ */
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
+};
+
+/**
+ * @brief Retrieve the name associated with a CANN buffer.
+ *
+ * This function returns the fixed name used to identify CANN buffers; the
+ * buffer argument itself is not used.
+ *
+ * @param buffer The CANN buffer whose name is to be retrieved.
+ * @return A pointer to a C-string containing the name of the buffer.
+ */
+GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
+ ggml_backend_buffer_t buffer) {
+ return "CANN";
+
+ GGML_UNUSED(buffer);
+}
+
+/**
+ * @brief Check if a buffer is a CANN buffer.
+ *
+ * This function checks if a given buffer is a CANN buffer by comparing its
+ * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
+ *
+ * @param buffer The buffer to check.
+ * @return true if the buffer is a CANN buffer, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_buffer_is_cann(
+ ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_cann_buffer_get_name;
+}
+
+/**
+ * @brief Free resources associated with a CANN buffer.
+ *
+ * This function frees the resources associated with a CANN buffer, including
+ * its context.
+ *
+ * @param buffer The CANN buffer to free.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_free_buffer(
+ ggml_backend_buffer_t buffer) {
+ ggml_backend_cann_buffer_context* ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+ delete ctx;
+}
+
+/**
+ * @brief Retrieve the base pointer of a CANN buffer.
+ *
+ * This function returns the base pointer of a CANN buffer, which points to the
+ * device memory allocated for the buffer.
+ *
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
+ * @return A pointer to the base of the device memory allocated for the buffer.
+ */
+GGML_CALL static void* ggml_backend_cann_buffer_get_base(
+ ggml_backend_buffer_t buffer) {
+ ggml_backend_cann_buffer_context* ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+ return ctx->dev_ptr;
+}
+
+/**
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
+ * processing.
+ *
+ * This function transforms quantized Q4.0 tensor data into a format suitable
+ * for CANN processing. It extracts quantization values and scales from the
+ * source data and prepares them in a format expected by CANN operations.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data in Q4.0 format.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
+ const void* src,
+ void* dst) {
+ GGML_ASSERT(tensor->op == GGML_OP_NONE);
+
+ int64_t n_elems = ggml_nelements(tensor);
+ int64_t groups = n_elems / QK4_0;
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
+
+ uint8_t* quant_offset = (uint8_t*)dst;
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
+
+ for (int i = 0; i < groups; i++) {
+ const block_q4_0* group =
+ (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
+ *scale_offset = group->d;
+ scale_offset++;
+
+ // 0-15
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
+ (*quant_offset) = (group->qs[j] & 0x0F);
+ (*quant_offset) |= ((group->qs[j + 1] << 4));
+ quant_offset++;
+ }
+
+ // 16-31
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
+ (*quant_offset) = (group->qs[j] >> 4);
+ (*quant_offset) |= (group->qs[j + 1] & 0xF0);
+ quant_offset++;
+ }
+ }
+
+ // put (uint4b_t -8) into int4b_t
+ for (quant_offset = (uint8_t*)dst;
+ quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
+ (*quant_offset) ^= 0x88;
+ }
+}
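+
+// Resulting layout of the transformed Q4_0 data (derived from the loops
+// above): dst holds all n_elems 4-bit values first, two per byte in element
+// order (the lower-indexed element in the low nibble, the next one in the high
+// nibble), followed by one fp16 scale per 32-element group. The final XOR with
+// 0x88 flips the sign bit of each nibble, turning the unsigned value q in
+// [0, 15] into the two's-complement int4 value q - 8.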
+
+/**
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
+ *
+ * This function transforms CANN processed data back into quantized Q4.0 format.
+ * It reverses the transformation performed by
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
+ * original quantized form.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source buffer containing transformed data.
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
+ const ggml_tensor* tensor, void* src, void* dst) {
+ GGML_ASSERT(tensor->op == GGML_OP_NONE);
+
+ int64_t n_elems = ggml_nelements(tensor);
+ int64_t groups = n_elems / QK4_0;
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
+
+ uint8_t* quant_offset = (uint8_t*)src;
+ uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
+
+ for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
+ (*quant_offset) ^= 0x88;
+ }
+ quant_offset = (uint8_t*)src;
+
+ for (int i = 0; i < groups; i++) {
+ block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
+ group->d = *scale_offset;
+ scale_offset++;
+
+ // 0-15
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
+ group->qs[j] = ((*quant_offset) & 0x0F);
+ group->qs[j + 1] = ((*quant_offset) >> 4);
+ quant_offset++;
+ }
+
+ // 16-31
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
+ group->qs[j] |= ((*quant_offset) << 4);
+ group->qs[j + 1] |= ((*quant_offset) & 0xF0);
+ quant_offset++;
+ }
+ }
+}
+
+/**
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
+ * processing.
+ *
+ * This function transforms quantized Q8.0 tensor data into a format suitable
+ * for CANN processing. It extracts quantization values and scales from the
+ * source data and prepares them in a format expected by CANN operations.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data in Q8.0 format.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
+ const void* src,
+ void* dst) {
+ int64_t n_elems = ggml_nelements(tensor);
+ int64_t groups = n_elems / QK8_0;
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+ uint8_t* quant_offset = (uint8_t*)dst;
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
+
+ for (int i = 0; i < groups; i++) {
+ const block_q8_0* group =
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
+ *scale_offset = group->d;
+ scale_offset++;
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+ memcpy(quant_offset, group->qs, group_quant_size);
+ quant_offset += group_quant_size;
+ }
+}
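+
+// Resulting layout of the transformed Q8_0 data: dst holds the n_elems int8
+// quantized values contiguously in element order, followed by one fp16 scale
+// per 32-element group.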
+
+/**
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
+ *
+ * This function transforms CANN processed data back into quantized Q8.0 format.
+ * It reverses the transformation performed by
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
+ * original quantized form.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source buffer containing transformed data.
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back_q8_0(
+ const ggml_tensor* tensor, const void* src, void* dst) {
+ int64_t n_elems = ggml_nelements(tensor);
+ int64_t groups = n_elems / QK8_0;
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+ const uint8_t* quant_offset = (const uint8_t*)src;
+ const uint16_t* scale_offset =
+ (const uint16_t*)((const char*)src + quant_bytes);
+
+ for (int i = 0; i < groups; i++) {
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
+ group->d = *scale_offset;
+ scale_offset++;
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+ memcpy(group->qs, quant_offset, group_quant_size);
+ quant_offset += group_quant_size;
+ }
+}
+
+/**
+ * @brief Transform tensor data based on its type for CANN processing.
+ *
+ * This function transforms tensor data based on its quantization type for CANN
+ * processing. It dispatches the transformation based on the tensor's type to
+ * specialized functions handling Q4.0 and Q8.0 formats.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data to be transformed.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor,
+ const void* src, void* dst) {
+ switch (tensor->type) {
+ case GGML_TYPE_Q4_0:
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
+ break;
+ default:
+ break;
+ }
+}
+
+/**
+ * @brief Transform CANN processed data back into tensor data based on its type.
+ *
+ * This function transforms CANN processed data back into tensor data based on
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
+ * transformation based on the tensor's type to specialized functions.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data containing CANN processed data.
+ * @param dst Pointer to the destination buffer where transformed tensor data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back(
+ const ggml_tensor* tensor, void* src, void* dst) {
+ switch (tensor->type) {
+ case GGML_TYPE_Q4_0:
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
+ break;
+ default:
+ break;
+ }
+}
+
+/**
+ * @brief Check if transformation is needed for a given tensor type.
+ *
+ * This function checks if transformation is needed for a given tensor type
+ * to prepare data for CANN processing.
+ *
+ * @param type The tensor type to check.
+ * @return true if transformation is needed, false otherwise.
+ */
+GGML_CALL static bool need_transform(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/**
+ * @brief Initialize a tensor using data from a CANN buffer.
+ *
+ * This function initializes a tensor using data from a CANN buffer.
+ * It handles special cases such as views and quantization.
+ *
+ * @param buffer The CANN buffer from which to initialize the tensor.
+ * @param tensor Pointer to the tensor to be initialized.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+ return;
+ }
+
+    // TODO: the CANN backend does not support quantized tensors yet; keep
+    // this code here for now.
+ if (ggml_is_quantized(tensor->type)) {
+ // Initialize padding to 0 to avoid possible NaN values
+ size_t original_size = ggml_nbytes(tensor);
+ size_t padded_size =
+ ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+ if (padded_size > original_size && tensor->view_src == nullptr) {
+ size_t memset_size = padded_size - original_size;
+ ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
+ memset_size, 0, memset_size));
+ }
+ }
+}
+
+// TODO: need to handle tensors that have padding.
+/**
+ * @brief Set tensor data in a CANN buffer.
+ *
+ * This function sets tensor data in a CANN buffer, handling transformations
+ * if needed based on the tensor's type.
+ *
+ * @param buffer The CANN buffer where the tensor data will be set.
+ * @param tensor Pointer to the tensor whose data will be set.
+ * @param data Pointer to the source data to be copied into the tensor.
+ * @param offset Offset in the source data from where to start copying.
+ * @param size Size of the data to be copied, in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor, const void* data,
+ size_t offset, size_t size) {
+ // GGML_ASSERT(size == ggml_nbytes(tensor));
+ ggml_backend_cann_buffer_context* ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+
+ ggml_cann_set_device(ctx->device);
+    // TODO: refer to cann(#6017); it uses the thread's default stream.
+    // For ACL, synchronous functions use this default stream.
+    // Why aclrtSynchronizeDevice?
+
+ if (!need_transform(tensor->type)) {
+ ACL_CHECK(aclrtMemcpy(tensor->data, size, (const char*)data + offset,
+ size, ACL_MEMCPY_HOST_TO_DEVICE));
+ } else {
+ void* transform_buffer = malloc(size);
+ ggml_backend_cann_transform(tensor, (const char*)data + offset,
+ transform_buffer);
+
+#ifndef NDEBUG
+ void* check_buffer = malloc(size);
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
+ check_buffer);
+ GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size) ==
+ 0);
+ free(check_buffer);
+#endif
+ ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size,
+ ACL_MEMCPY_HOST_TO_DEVICE));
+ free(transform_buffer);
+ }
+}
+
+/**
+ * @brief Get tensor data from a CANN buffer.
+ *
+ * This function retrieves tensor data from a CANN buffer, handling
+ * transformations if needed based on the tensor's type.
+ *
+ * @param buffer The CANN buffer from which to retrieve tensor data.
+ * @param tensor Pointer to the tensor whose data will be retrieved.
+ * @param data Pointer to the destination buffer where the tensor data will be
+ * copied.
+ * @param offset Offset in the destination buffer where to start copying.
+ * @param size Size of the data to be copied, in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_get_tensor(
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
+ size_t offset, size_t size) {
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+ ggml_backend_cann_buffer_context* ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+
+ ggml_cann_set_device(ctx->device);
+
+ if (!need_transform(tensor->type)) {
+ ACL_CHECK(aclrtMemcpy((char*)data + offset, size, tensor->data, size,
+ ACL_MEMCPY_DEVICE_TO_HOST));
+ } else {
+ void* transform_buffer = malloc(size);
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size, tensor->data, size,
+ ACL_MEMCPY_DEVICE_TO_HOST));
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
+ (char*)data + offset);
+ free(transform_buffer);
+ }
+}
+
+/**
+ * @brief Copy tensor data between CANN buffers if possible.
+ *
+ * This function copies tensor data between CANN buffers if the source and
+ * destination buffers are CANN buffers and they meet the necessary conditions
+ * (same device or devices can access each other).
+ *
+ * @param buffer The destination CANN buffer where the tensor data will be
+ * copied.
+ * @param src Pointer to the source tensor whose data will be copied.
+ * @param dst Pointer to the destination tensor where the data will be copied.
+ * @return true if the copy operation succeeded, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor(
+ ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
+ if (ggml_backend_buffer_is_cann(src->buffer)) {
+ ggml_backend_cann_buffer_context* src_ctx =
+ (ggml_backend_cann_buffer_context*)src->buffer->context;
+ ggml_backend_cann_buffer_context* dst_ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+
+ size_t memcpy_size = ggml_nbytes(src);
+ // Same device.
+ if (src_ctx->device == dst_ctx->device) {
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
+ (const char*)src->data, memcpy_size,
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
+ return true;
+ } else {
+ // Different device but can access by peer.
+ int32_t canAccessPeer = 0;
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
+ dst_ctx->device));
+ if (canAccessPeer) {
+ ggml_cann_set_device(src_ctx->device);
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
+ (const char*)src->data, memcpy_size,
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+/**
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
+ *
+ * This function clears a CANN buffer by setting all its memory to a specified
+ * value.
+ *
+ * @param buffer The CANN buffer to be cleared.
+ * @param value The value to which each byte in the buffer will be set.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_clear(
+ ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_cann_buffer_context* ctx =
+ (ggml_backend_cann_buffer_context*)buffer->context;
+
+ ggml_cann_set_device(ctx->device);
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
+}
+
+/**
+ * @brief Interface for a CANN buffer in the backend.
+ *
+ * This structure defines function pointers to operations that can be performed
+ * on a CANN buffer within the backend.
+ */
+static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
+ /* .get_name = */ ggml_backend_cann_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cann_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_cann_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// cann buffer type
+/**
+ * @brief Structure representing context information for a specific backend
+ * buffer type.
+ */
+struct ggml_backend_cann_buffer_type_context {
+ int32_t
+ device; /**< Device identifier associated with the buffer context. */
+ std::string name; /**< Name associated with the buffer context. */
+};
+
+/**
+ * @brief Retrieves the name associated with a CANN buffer type.
+ *
+ * This function returns the descriptive name associated with the specified
+ * CANN buffer type context.
+ *
+ * @param buft Pointer to the buffer type context.
+ * @return Const pointer to the C-style string containing the name.
+ */
+GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
+ ggml_backend_buffer_type_t buft) {
+ return "CANN";
+
+ GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Allocates a new CANN buffer of the specified type and size.
+ *
+ * This function allocates a new CANN buffer on the specified device with the
+ * given size.
+ *
+ * @param buft Pointer to the buffer type context.
+ * @param size Size in bytes of the buffer to allocate.
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
+ */
+GGML_CALL static ggml_backend_buffer_t
+ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+ size_t size) {
+ ggml_backend_cann_buffer_type_context* buft_ctx =
+ (ggml_backend_cann_buffer_type_context*)buft->context;
+
+ ggml_cann_set_device(buft_ctx->device);
+
+ size = std::max(size, (size_t)1);
+
+ void* dev_ptr;
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+ if (err != ACL_SUCCESS) {
+ GGML_CANN_LOG_ERROR(
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
+ aclGetRecentErrMsg());
+ return nullptr;
+ }
+
+ ggml_backend_cann_buffer_context* ctx =
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
+ ctx, size);
+}
+
+/**
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
+ * type.
+ *
+ * This function returns the alignment requirement in bytes for memory allocated
+ * by the CANN buffer type.
+ *
+ * @param buft Pointer to the buffer type context (unused in this
+ * implementation).
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
+ * buffers).
+ */
+GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment(
+ ggml_backend_buffer_type_t buft) {
+ return 128;
+
+ GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
+ *
+ * Computes the total allocation size needed for storing the tensor's data in a
+ * CANN buffer, considering any necessary padding or adjustments for quantized
+ * types.
+ *
+ * @param buft Pointer to the buffer type context (unused in this
+ * implementation).
+ * @param tensor Pointer to the tensor for which the allocation size is
+ * calculated.
+ * @return The total allocation size in bytes required for the tensor in the
+ * CANN buffer.
+ */
+GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size(
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+ size_t size = ggml_nbytes(tensor);
+ int64_t ne0 = tensor->ne[0];
+
+    // The last line must be bigger than 32 bytes, because every single op
+    // deals with at least 32 bytes.
+ // TODO: quantized type?
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
+ // size += (line_size_align_32 - line_size);
+
+    // TODO: quantized types are not supported yet.
+    // TODO: consider non-contiguous tensors.
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return size;
+
+ GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Interface for managing CANN buffer types in the GGML backend.
+ *
+ * Provides function pointers for allocating, querying properties, and managing
+ * memory for CANN buffer types in the GGML backend.
+ */
+static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_cann_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL,
+};
+
+/**
+ * @brief Retrieves the CANN buffer type for a specified device.
+ *
+ * This function initializes and returns the buffer type interface associated
+ * with the given device. It ensures thread-safe access using a mutex.
+ *
+ * @param device The device index for which to retrieve the buffer type.
+ * @return A pointer to the buffer type interface for the specified device, or
+ * nullptr if the device index is out of range.
+ */
+GGML_CALL ggml_backend_buffer_type_t
+ggml_backend_cann_buffer_type(int32_t device) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ if (device >= ggml_backend_cann_get_device_count()) {
+ return nullptr;
+ }
+
+ static ggml_backend_buffer_type
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
+
+ static bool ggml_backend_cann_buffer_type_initialized = false;
+
+ if (!ggml_backend_cann_buffer_type_initialized) {
+ for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
+ ggml_backend_cann_buffer_types[i] = {
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
+ /* .context = */
+ new ggml_backend_cann_buffer_type_context{
+ i, "CANN" + std::to_string(i)},
+ };
+ }
+ ggml_backend_cann_buffer_type_initialized = true;
+ }
+
+ return &ggml_backend_cann_buffer_types[device];
+}
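+
+// Typical use of the buffer type above (a minimal sketch using the generic
+// ggml-backend helpers, which are assumed to be available to the caller):
+//
+//     ggml_backend_buffer_type_t buft = ggml_backend_cann_buffer_type(0);
+//     ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 1 << 20);
+//     // ... bind tensors to the buffer, run graphs ...
+//     ggml_backend_buffer_free(buf);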
+
+/**
+ * @brief Computes the forward operation for a given tensor using CANN
+ * operations.
+ *
+ * This function selects the appropriate CANN operation based on the type of
+ * operation specified in the tensor and performs the computation.
+ *
+ * @param ctx The CANN context containing necessary resources and
+ * configurations.
+ * @param dst The destination tensor where the result of the computation will be
+ * stored.
+ * @return true if the computation was successful; false otherwise.
+ */
+static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
+ struct ggml_tensor* dst) {
+ switch (dst->op) {
+ case GGML_OP_REPEAT:
+ ggml_cann_repeat(ctx, dst);
+ break;
+ case GGML_OP_GET_ROWS:
+ ggml_cann_get_rows(ctx, dst);
+ break;
+ case GGML_OP_DUP:
+ ggml_cann_dup(ctx, dst);
+ break;
+ case GGML_OP_ADD:
+ ggml_cann_add(ctx, dst);
+ break;
+ case GGML_OP_ACC:
+ ggml_cann_acc(ctx, dst);
+ break;
+ case GGML_OP_MUL:
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+ break;
+ case GGML_OP_DIV:
+ ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_GELU:
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
+ ctx, dst);
+ break;
+ case GGML_UNARY_OP_SILU:
+ ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
+ ctx, dst);
+ break;
+ // TODO: Use faster gelu??
+ case GGML_UNARY_OP_GELU_QUICK:
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
+ ctx, dst);
+ break;
+ case GGML_UNARY_OP_TANH:
+ ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
+ ctx, dst);
+ break;
+ case GGML_UNARY_OP_RELU:
+ ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
+ ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
+ aclnnHardsigmoid>(ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
+ aclnnHardswish>(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_NORM:
+ ggml_cann_norm(ctx, dst);
+ break;
+ case GGML_OP_GROUP_NORM:
+ ggml_cann_group_norm(ctx, dst);
+ break;
+ case GGML_OP_CONCAT:
+ ggml_cann_concat(ctx, dst);
+ break;
+ case GGML_OP_UPSCALE:
+ ggml_cann_upsample_nearest2d(ctx, dst);
+ break;
+ case GGML_OP_PAD:
+ ggml_cann_pad(ctx, dst);
+ break;
+ case GGML_OP_ARANGE:
+ ggml_cann_arange(ctx, dst);
+ break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ ggml_cann_timestep_embedding(ctx, dst);
+ break;
+ case GGML_OP_LEAKY_RELU:
+ ggml_cann_leaky_relu(ctx, dst);
+ break;
+ case GGML_OP_RMS_NORM:
+ ggml_cann_rms_norm(ctx, dst);
+ break;
+ case GGML_OP_MUL_MAT:
+ ggml_cann_mul_mat(ctx, dst);
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ return false;
+ case GGML_OP_SCALE:
+ ggml_cann_scale(ctx, dst);
+ break;
+ case GGML_OP_SQR:
+ ggml_cann_sqr(ctx, dst);
+ break;
+ case GGML_OP_CLAMP:
+ ggml_cann_clamp(ctx, dst);
+ break;
+ case GGML_OP_CPY:
+ ggml_cann_cpy(ctx, dst);
+ break;
+ case GGML_OP_CONT:
+ ggml_cann_dup(ctx, dst);
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ ggml_cann_diag_mask(ctx, dst, -INFINITY);
+ break;
+ case GGML_OP_SOFT_MAX:
+ ggml_cann_softmax(ctx, dst);
+ break;
+ case GGML_OP_ROPE:
+ ggml_cann_rope(ctx, dst);
+ break;
+ case GGML_OP_IM2COL:
+ ggml_cann_im2col(ctx, dst);
+ break;
+ case GGML_OP_POOL_2D:
+ ggml_cann_pool2d(ctx, dst);
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_cann_sum_rows(ctx, dst);
+ break;
+ case GGML_OP_ARGSORT:
+ ggml_cann_argsort(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+// backend
+/**
+ * @brief Retrieves the name associated with the CANN backend.
+ *
+ * This function returns the name assigned to the CANN backend, which is stored
+ * in the context of the provided backend structure.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @return A pointer to a constant string representing the backend name.
+ */
+GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ return cann_ctx->name.c_str();
+}
+
+/**
+ * @brief Frees resources associated with the CANN backend.
+ *
+ * This function releases resources associated with the CANN backend context
+ * and resets the device associated with the backend to its initial state.
+ *
+ * @param backend Pointer to the CANN backend structure to be freed.
+ */
+GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+ ACL_CHECK(aclrtSynchronizeDevice());
+ ACL_CHECK(aclrtResetDevice(cann_ctx->device));
+
+    // finalize when the last backend is freed.
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
+ ACL_CHECK(aclFinalize());
+ }
+
+ delete cann_ctx;
+ delete backend;
+}
+
+/**
+ * @brief Retrieves the default buffer type associated with the CANN backend.
+ *
+ * This function returns the buffer type specific to the device associated
+ * with the CANN backend. It is used to allocate buffers for computations
+ * performed by the backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @return Pointer to the buffer type structure for the CANN backend.
+ */
+GGML_CALL static ggml_backend_buffer_type_t
+ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ return ggml_backend_cann_buffer_type(cann_ctx->device);
+}
+
+/**
+ * @brief Sets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously sets tensor data in the CANN backend. Depending
+ * on the tensor type, it may perform data transformations before copying data
+ * to the device.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to set data for.
+ * @param data Pointer to the host data to copy to the tensor.
+ * @param offset Offset in bytes within the host data.
+ * @param size Size of the data to copy in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
+ ggml_tensor* tensor,
+ const void* data,
+ size_t offset,
+ size_t size) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ if (!need_transform(tensor->type)) {
+ ACL_CHECK(aclrtMemcpyAsync(
+ tensor->data, size, (const char*)data + offset, size,
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
+ } else {
+ void* transform_buffer = malloc(size);
+ ggml_backend_cann_transform(tensor, (const char*)data + offset,
+ transform_buffer);
+
+#ifndef NDEBUG
+ void* check_buffer = malloc(size);
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
+ check_buffer);
+        GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size) ==
+                    0);
+ free(check_buffer);
+#endif
+ ACL_CHECK(aclrtMemcpyAsync(tensor->data, size, transform_buffer, size,
+ ACL_MEMCPY_HOST_TO_DEVICE,
+ cann_ctx->stream()));
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+ free(transform_buffer);
+ }
+}
+
+GGML_CALL static void ggml_backend_cann_get_tensor_async(
+ ggml_backend_t backend, const ggml_tensor* tensor, void* data,
+ size_t offset, size_t size) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+ ggml_backend_buffer_t buf =
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+ "unsupported buffer type");
+
+ if (!need_transform(tensor->type)) {
+ ACL_CHECK(aclrtMemcpyAsync((char*)data + offset, size, tensor->data,
+ size, ACL_MEMCPY_DEVICE_TO_HOST,
+ cann_ctx->stream()));
+ } else {
+ void* transform_buffer = malloc(size);
+ ACL_CHECK(aclrtMemcpyAsync(transform_buffer, size, tensor->data, size,
+ ACL_MEMCPY_DEVICE_TO_HOST,
+ cann_ctx->stream()));
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
+ (char*)data + offset);
+ free(transform_buffer);
+ }
+}
+
+/**
+ * @brief Asynchronously copies tensor data between CANN backends.
+ *
+ * This function copies tensor data asynchronously between two CANN backends. It
+ * checks if both tensors reside in CANN buffers and whether the devices support
+ * peer-to-peer access for direct copying. If not, it returns false.
+ *
+ * @param backend_src Pointer to the source CANN backend structure.
+ * @param backend_dst Pointer to the destination CANN backend structure.
+ * @param src Pointer to the source tensor to copy data from.
+ * @param dst Pointer to the destination tensor to copy data to.
+ * @return true if the copy operation succeeds, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
+ ggml_backend_t backend_src, ggml_backend_t backend_dst,
+ const ggml_tensor* src, ggml_tensor* dst) {
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
+ ggml_backend_is_cann(backend_dst));
+
+ if (!ggml_backend_buffer_is_cann(src->buffer) ||
+ !ggml_backend_buffer_is_cann(dst->buffer)) {
+ return false;
+ }
+
+ ggml_backend_buffer_t buf_src =
+ src->view_src ? src->view_src->buffer : src->buffer;
+ ggml_backend_buffer_t buf_dst =
+ dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+ ggml_backend_cann_context* cann_ctx_src =
+ (ggml_backend_cann_context*)backend_src->context;
+ ggml_backend_cann_context* cann_ctx_dst =
+ (ggml_backend_cann_context*)backend_dst->context;
+
+ size_t copy_size = ggml_nbytes(dst);
+ if (backend_src != backend_dst) {
+ ggml_backend_cann_buffer_context* buf_ctx_src =
+ (ggml_backend_cann_buffer_context*)buf_src->context;
+ ggml_backend_cann_buffer_context* buf_ctx_dst =
+ (ggml_backend_cann_buffer_context*)buf_dst->context;
+
+ GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
+ GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
+
+ int32_t canAccessPeer = 0;
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
+ cann_ctx_dst->device));
+ if (!canAccessPeer) {
+ return false;
+ }
+
+ ggml_cann_set_device(cann_ctx_src->device);
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
+ cann_ctx_dst->stream()));
+
+ // record event on src stream
+ if (!cann_ctx_src->copy_event) {
+ ACL_CHECK(aclrtCreateEvent(&cann_ctx_src->copy_event));
+ }
+
+ ACL_CHECK(
+ aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
+
+ // wait on dst stream for the copy to complete
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(),
+ cann_ctx_src->copy_event));
+ } else {
+ // src and dst are on the same backend
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
+ cann_ctx_dst->stream()));
+ }
+
+ return true;
+}
+
+/**
+ * @brief Synchronizes a CANN backend.
+ *
+ * This function synchronizes the specified CANN backend by waiting for all
+ * operations in its associated stream to complete.
+ *
+ * @param backend Pointer to the CANN backend structure to synchronize.
+ */
+GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ ggml_cann_set_device(cann_ctx->device);
+
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+}
+
+/**
+ * @brief Computes a computational graph using a CANN backend.
+ *
+ * This function computes the operations defined in the computational graph
+ * using the specified CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure to use for computation.
+ * @param cgraph Pointer to the computational graph structure containing nodes
+ * representing operations to be computed.
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
+ * completes successfully, otherwise an appropriate error status.
+ */
+GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute(
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ ggml_cann_set_device(cann_ctx->device);
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor* node = cgraph->nodes[i];
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
+
+ if (!ok) {
+ GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
+ node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+ }
+
+ return GGML_STATUS_SUCCESS;
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific operation.
+ *
+ * This function checks whether the specified operation is supported by the
+ * CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure to check support for
+ * the operation.
+ * @param op Pointer to the tensor representing the operation to check.
+ * @return bool Returns true if the operation is supported by the backend,
+ * otherwise false.
+ */
+GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
+ const ggml_tensor* op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_TANH:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_OP_MUL_MAT: {
+ switch (op->src[0]->type) {
+ // case GGML_TYPE_Q4_0:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ }
+ case GGML_OP_MUL_MAT_ID:
+ return false;
+ // embedding
+ case GGML_OP_GET_ROWS: {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY: {
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ }
+ case GGML_OP_DUP:
+ case GGML_OP_REPEAT:
+ case GGML_OP_CONCAT:
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NORM:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CONT:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ case GGML_OP_POOL_2D:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_ACC:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
+ return true;
+ default:
+ return false;
+ }
+
+ GGML_UNUSED(backend);
+}
+
+/**
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
+ *
+ * This function checks whether the provided backend buffer type is associated
+ * with the CANN backend based on the comparison of its name retrieval function
+ * pointer.
+ *
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the buffer type is associated with the CANN
+ * backend, otherwise false.
+ */
+static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
+ *
+ * This function determines whether the CANN backend supports the given backend
+ * buffer type by comparing the device context of the backend and buffer type.
+ * It returns true if the device associated with the buffer type matches the
+ * device associated with the backend.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the CANN backend supports the buffer type,
+ * otherwise false.
+ */
+GGML_CALL static bool ggml_backend_cann_supports_buft(
+ ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
+
+ GGML_UNUSED(backend);
+}
+
+/**
+ * @brief Determines if a tensor operation should be offloaded to the CANN
+ * backend.
+ *
+ * This function checks if a given tensor operation should be offloaded to the
+ * CANN backend based on the operation type and the size of the tensor. It
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @param op Pointer to the tensor operation to check.
+ * @return bool Returns true if the operation should be offloaded, otherwise
+ * false.
+ */
+GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
+ const ggml_tensor* op) {
+ const int min_batch_size = 32;
+ GGML_UNUSED(backend);
+
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+}
+
+/**
+ * @brief Creates a new event for the CANN backend.
+ *
+ * This function initializes a new event for the CANN backend by setting the
+ * device and creating an ACL runtime event. The created event is then wrapped
+ * in a ggml_backend_event structure and returned.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
+ */
+static ggml_backend_event_t ggml_backend_cann_event_new(
+ ggml_backend_t backend) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ ggml_cann_set_device(cann_ctx->device);
+
+ aclrtEvent event;
+ ACL_CHECK(aclrtCreateEvent(&event));
+
+ return new ggml_backend_event{
+ /* .backend = */ backend,
+ /* .context = */ event,
+ };
+}
+
+/**
+ * @brief Frees a CANN backend event.
+ *
+ * This function destroys the ACL runtime event associated with the given CANN
+ * backend event and then deletes the event structure itself.
+ *
+ * @param event Pointer to the event structure to be freed.
+ */
+static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
+
+ delete event;
+}
+
+/**
+ * @brief Records an event on the CANN backend stream.
+ *
+ * This function records the given event on the ACL runtime stream associated
+ * with the backend context.
+ *
+ * @param event Pointer to the event structure to be recorded.
+ */
+static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)event->backend->context;
+
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
+}
+
+/**
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
+ *
+ * This function makes the given backend wait for the event to complete on its
+ * ACL runtime stream.
+ *
+ * @param backend Pointer to the backend structure.
+ * @param event Pointer to the event structure that the backend needs to wait
+ * for.
+ */
+static void ggml_backend_cann_event_wait(ggml_backend_t backend,
+ ggml_backend_event_t event) {
+ ggml_backend_cann_context* cann_ctx =
+ (ggml_backend_cann_context*)backend->context;
+
+ if (ggml_backend_is_cann(event->backend)) {
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
+ (aclrtEvent)event->context));
+ } else {
+ GGML_ASSERT(false);
+ }
+}
+
+/**
+ * @brief Synchronizes the given event on the CANN backend.
+ *
+ * This function waits for the specified event to complete on the ACL runtime.
+ *
+ * @param event Pointer to the event structure to be synchronized.
+ */
+static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
+}
+
+/**
+ * @brief Structure defining the interface for the CANN backend.
+ *
+ * This structure contains function pointers for various operations
+ * supported by the CANN backend, including name retrieval, memory
+ * management, tensor operations, synchronization, and event handling.
+ */
+static ggml_backend_i ggml_backend_cann_interface = {
+ /* .get_name = */ ggml_backend_cann_name,
+ /* .free = */ ggml_backend_cann_free,
+ /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
+ /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
+ /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
+ /* .synchronize = */ ggml_backend_cann_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_cann_graph_compute,
+ /* .supports_op = */ ggml_backend_cann_supports_op,
+ /* .supports_buft = */ ggml_backend_cann_supports_buft,
+ /* .offload_op = */ ggml_backend_cann_offload_op,
+ /* .event_new = */ ggml_backend_cann_event_new,
+ /* .event_free = */ ggml_backend_cann_event_free,
+ /* .event_record = */ ggml_backend_cann_event_record,
+ /* .event_wait = */ ggml_backend_cann_event_wait,
+ /* .event_synchronize = */ ggml_backend_cann_event_synchronize,
+};
+
+/**
+ * @brief Return the hardcoded GUID for the CANN backend.
+ *
+ * This function returns a static GUID which uniquely identifies the CANN
+ * backend.
+ *
+ * @return A pointer to the static GUID.
+ */
+static ggml_guid_t ggml_backend_cann_guid() {
+ static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
+ return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
+ aclInit(nullptr);
+ if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
+ GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
+ return nullptr;
+ }
+
+ ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
+ if (ctx == nullptr) {
+ GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+ return nullptr;
+ }
+
+ ggml_backend_t cann_backend =
+ new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
+ /* .interface = */ ggml_backend_cann_interface,
+ /* .context = */ ctx};
+
+ return cann_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) {
+ return backend != NULL &&
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
+}
+
+GGML_CALL int32_t ggml_backend_cann_get_device_count() {
+ return ggml_cann_info().device_count;
+}
+
+GGML_CALL void ggml_backend_cann_get_device_description(
+ int32_t device, char* description, size_t description_size) {
+ ggml_cann_set_device(device);
+ const char* soc_name = aclrtGetSocName();
+ snprintf(description, description_size, "%s", soc_name);
+}
+
+GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
+ size_t* total) {
+ ggml_cann_set_device(device);
+ ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
+}
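+
+// Minimal usage sketch for the public API above (error handling elided; the
+// generic ggml_backend_free() helper is assumed to be available):
+//
+//     if (ggml_backend_cann_get_device_count() > 0) {
+//         ggml_backend_t backend = ggml_backend_cann_init(0);
+//         size_t free_mem, total_mem;
+//         ggml_backend_cann_get_device_memory(0, &free_mem, &total_mem);
+//         // ... build a ggml graph and run it with ggml_backend_graph_compute ...
+//         ggml_backend_free(backend);
+//     }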
+
+// backend registry
+/**
+ * @brief Initializes a CANN backend based on the provided parameters.
+ *
+ * This function initializes a CANN backend using the device index and then
+ * initializes the backend using `ggml_backend_cann_init`.
+ *
+ * @param params Parameters for initialization (unused in this implementation).
+ * @param user_data User data containing the device index to initialize the
+ * backend.
+ * @return ggml_backend_t The initialized CANN backend.
+ */
+GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params,
+ void* user_data) {
+ ggml_backend_t cann_backend =
+ ggml_backend_cann_init((int)(intptr_t)user_data);
+ return cann_backend;
+
+ GGML_UNUSED(params);
+}
+
+extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
+
+/**
+ * @brief Registers CANN (Ascend) devices as backend options.
+ *
+ * This function initializes ACL, retrieves the number of available CANN
+ * devices, and registers each device as a backend option using
+ * `ggml_backend_register`. Each device is given a unique name based on
+ * `GGML_CANN_NAME` followed by its index.
+ *
+ * @return int The number of CANN devices registered.
+ */
+GGML_CALL int ggml_backend_cann_reg_devices() {
+ uint32_t device_count = ggml_backend_cann_get_device_count();
+ // initialization
+ for (uint32_t i = 0; i < device_count; i++) {
+ char name[128];
+ snprintf(name, sizeof(name), "CANN%d", i);
+ ggml_backend_register(name, ggml_backend_reg_cann_init,
+ ggml_backend_cann_buffer_type(i),
+ (void*)(intptr_t)i);
+ }
+ return device_count;
+}
diff --git a/ggml/src/ggml-cann/.clang-format b/ggml/src/ggml-cann/.clang-format
new file mode 100644
index 00000000..2ad03d73
--- /dev/null
+++ b/ggml/src/ggml-cann/.clang-format
@@ -0,0 +1,168 @@
+---
+Language: Cpp
+# BasedOnStyle: Google
+AccessModifierOffset: -1
+AlignAfterOpenBracket: Align
+AlignConsecutiveMacros: false
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+AlignEscapedNewlines: Left
+AlignOperands: true
+AlignTrailingComments: true
+AllowAllArgumentsOnNextLine: true
+AllowAllConstructorInitializersOnNextLine: true
+AllowAllParametersOfDeclarationOnNextLine: true
+AllowShortBlocksOnASingleLine: Never
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: All
+AllowShortLambdasOnASingleLine: All
+AllowShortIfStatementsOnASingleLine: WithoutElse
+AllowShortLoopsOnASingleLine: true
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakAfterReturnType: None
+AlwaysBreakBeforeMultilineStrings: true
+AlwaysBreakTemplateDeclarations: Yes
+BinPackArguments: true
+BinPackParameters: true
+BraceWrapping:
+ AfterCaseLabel: false
+ AfterClass: false
+ AfterControlStatement: false
+ AfterEnum: false
+ AfterFunction: false
+ AfterNamespace: false
+ AfterObjCDeclaration: false
+ AfterStruct: false
+ AfterUnion: false
+ AfterExternBlock: false
+ BeforeCatch: false
+ BeforeElse: false
+ IndentBraces: false
+ SplitEmptyFunction: true
+ SplitEmptyRecord: true
+ SplitEmptyNamespace: true
+BreakBeforeBinaryOperators: None
+BreakBeforeBraces: Attach
+BreakBeforeInheritanceComma: false
+BreakInheritanceList: BeforeColon
+BreakBeforeTernaryOperators: true
+BreakConstructorInitializersBeforeComma: false
+BreakConstructorInitializers: BeforeColon
+BreakAfterJavaFieldAnnotations: false
+BreakStringLiterals: true
+ColumnLimit: 80
+CommentPragmas: '^ IWYU pragma:'
+CompactNamespaces: false
+ConstructorInitializerAllOnOneLineOrOnePerLine: true
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+Cpp11BracedListStyle: true
+DeriveLineEnding: true
+DerivePointerAlignment: true
+DisableFormat: false
+ExperimentalAutoDetectBinPacking: false
+FixNamespaceComments: true
+ForEachMacros:
+ - foreach
+ - Q_FOREACH
+ - BOOST_FOREACH
+IncludeBlocks: Regroup
+IncludeCategories:
+ - Regex: '^<ext/.*\.h>'
+ Priority: 2
+ SortPriority: 0
+ - Regex: '^<.*\.h>'
+ Priority: 1
+ SortPriority: 0
+ - Regex: '^<.*'
+ Priority: 2
+ SortPriority: 0
+ - Regex: '.*'
+ Priority: 3
+ SortPriority: 0
+IncludeIsMainRegex: '([-_](test|unittest))?$'
+IncludeIsMainSourceRegex: ''
+IndentCaseLabels: true
+IndentGotoLabels: true
+IndentPPDirectives: None
+IndentWidth: 4
+IndentWrappedFunctionNames: false
+JavaScriptQuotes: Leave
+JavaScriptWrapImports: true
+KeepEmptyLinesAtTheStartOfBlocks: false
+MacroBlockBegin: ''
+MacroBlockEnd: ''
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBinPackProtocolList: Never
+ObjCBlockIndentWidth: 2
+ObjCSpaceAfterProperty: false
+ObjCSpaceBeforeProtocolList: true
+PenaltyBreakAssignment: 2
+PenaltyBreakBeforeFirstCallParameter: 1
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyBreakTemplateDeclaration: 10
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 200
+PointerAlignment: Left
+RawStringFormats:
+ - Language: Cpp
+ Delimiters:
+ - cc
+ - CC
+ - cpp
+ - Cpp
+ - CPP
+ - 'c++'
+ - 'C++'
+ CanonicalDelimiter: ''
+ BasedOnStyle: google
+ - Language: TextProto
+ Delimiters:
+ - pb
+ - PB
+ - proto
+ - PROTO
+ EnclosingFunctions:
+ - EqualsProto
+ - EquivToProto
+ - PARSE_PARTIAL_TEXT_PROTO
+ - PARSE_TEST_PROTO
+ - PARSE_TEXT_PROTO
+ - ParseTextOrDie
+ - ParseTextProtoOrDie
+ CanonicalDelimiter: ''
+ BasedOnStyle: google
+ReflowComments: true
+SortIncludes: true
+SortUsingDeclarations: true
+SpaceAfterCStyleCast: false
+SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCpp11BracedList: false
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceInEmptyBlock: false
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 2
+SpacesInAngles: false
+SpacesInConditionalStatement: false
+SpacesInContainerLiterals: true
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+SpaceBeforeSquareBrackets: false
+Standard: Auto
+StatementMacros:
+ - Q_UNUSED
+ - QT_REQUIRE_VERSION
+TabWidth: 8
+UseCRLF: false
+UseTab: Never
+...
+
diff --git a/ggml/src/ggml-cann/Doxyfile b/ggml/src/ggml-cann/Doxyfile
new file mode 100644
index 00000000..2b009e8f
--- /dev/null
+++ b/ggml/src/ggml-cann/Doxyfile
@@ -0,0 +1,2579 @@
+# Doxyfile 1.8.17
+
+# This file describes the settings to be used by the documentation system
+# doxygen (www.doxygen.org) for a project.
+#
+# All text after a double hash (##) is considered a comment and is placed in
+# front of the TAG it is preceding.
+#
+# All text after a single hash (#) is considered a comment and will be ignored.
+# The format is:
+# TAG = value [value, ...]
+# For lists, items can also be appended using:
+# TAG += value [value, ...]
+# Values that contain spaces should be placed between quotes (\" \").
+
+#---------------------------------------------------------------------------
+# Project related configuration options
+#---------------------------------------------------------------------------
+
+# This tag specifies the encoding used for all characters in the configuration
+# file that follow. The default is UTF-8 which is also the encoding used for all
+# text before the first occurrence of this tag. Doxygen uses libiconv (or the
+# iconv built into libc) for the transcoding. See
+# https://www.gnu.org/software/libiconv/ for the list of possible encodings.
+# The default value is: UTF-8.
+
+DOXYFILE_ENCODING = UTF-8
+
+# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by
+# double-quotes, unless you are using Doxywizard) that should identify the
+# project for which the documentation is generated. This name is used in the
+# title of most generated pages and in a few other places.
+# The default value is: My Project.
+
+PROJECT_NAME = "llama.cpp"
+
+# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
+# could be handy for archiving the generated documentation or if some version
+# control system is used.
+
+PROJECT_NUMBER =
+
+# Using the PROJECT_BRIEF tag one can provide an optional one line description
+# for a project that appears at the top of each page and should give viewer a
+# quick idea about the purpose of the project. Keep the description short.
+
+PROJECT_BRIEF = "llama inference engine"
+
+# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
+# in the documentation. The maximum height of the logo should not exceed 55
+# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
+# the logo to the output directory.
+
+PROJECT_LOGO =
+
+# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
+# into which the generated documentation will be written. If a relative path is
+# entered, it will be relative to the location where doxygen was started. If
+# left blank the current directory will be used.
+
+OUTPUT_DIRECTORY = docs
+
+# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
+# directories (in 2 levels) under the output directory of each output format and
+# will distribute the generated files over these directories. Enabling this
+# option can be useful when feeding doxygen a huge amount of source files, where
+# putting all generated files in the same directory would otherwise cause
+# performance problems for the file system.
+# The default value is: NO.
+
+CREATE_SUBDIRS = NO
+
+# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII
+# characters to appear in the names of generated files. If set to NO, non-ASCII
+# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode
+# U+3044.
+# The default value is: NO.
+
+ALLOW_UNICODE_NAMES = NO
+
+# The OUTPUT_LANGUAGE tag is used to specify the language in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all constant output in the proper language.
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese,
+# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States),
+# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian,
+# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages),
+# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian,
+# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian,
+# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish,
+# Ukrainian and Vietnamese.
+# The default value is: English.
+
+OUTPUT_LANGUAGE = English
+
+# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all output in the proper direction.
+# Possible values are: None, LTR, RTL and Context.
+# The default value is: None.
+
+OUTPUT_TEXT_DIRECTION = None
+
+# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
+# descriptions after the members that are listed in the file and class
+# documentation (similar to Javadoc). Set to NO to disable this.
+# The default value is: YES.
+
+BRIEF_MEMBER_DESC = YES
+
+# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief
+# description of a member or function before the detailed description
+#
+# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
+# brief descriptions will be completely suppressed.
+# The default value is: YES.
+
+REPEAT_BRIEF = YES
+
+# This tag implements a quasi-intelligent brief description abbreviator that is
+# used to form the text in various listings. Each string in this list, if found
+# as the leading text of the brief description, will be stripped from the text
+# and the result, after processing the whole list, is used as the annotated
+# text. Otherwise, the brief description is used as-is. If left blank, the
+# following values are used ($name is automatically replaced with the name of
+# the entity): The $name class, The $name widget, The $name file, is, provides,
+# specifies, contains, represents, a, an and the.
+
+ABBREVIATE_BRIEF = "The $name class" \
+ "The $name widget" \
+ "The $name file" \
+ is \
+ provides \
+ specifies \
+ contains \
+ represents \
+ a \
+ an \
+ the
+
+# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then
+# doxygen will generate a detailed section even if there is only a brief
+# description.
+# The default value is: NO.
+
+ALWAYS_DETAILED_SEC = NO
+
+# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all
+# inherited members of a class in the documentation of that class as if those
+# members were ordinary class members. Constructors, destructors and assignment
+# operators of the base classes will not be shown.
+# The default value is: NO.
+
+INLINE_INHERITED_MEMB = NO
+
+# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
+# before file names in the file list and in the header files. If set to NO, the
+# shortest path that makes the file name unique will be used.
+# The default value is: YES.
+
+FULL_PATH_NAMES = YES
+
+# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path.
+# Stripping is only done if one of the specified strings matches the left-hand
+# part of the path. The tag can be used to show relative paths in the file list.
+# If left blank the directory from which doxygen is run is used as the path to
+# strip.
+#
+# Note that you can specify absolute paths here, but also relative paths, which
+# will be relative from the directory where doxygen is started.
+# This tag requires that the tag FULL_PATH_NAMES is set to YES.
+
+STRIP_FROM_PATH =
+
+# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
+# path mentioned in the documentation of a class, which tells the reader which
+# header file to include in order to use a class. If left blank only the name of
+# the header file containing the class definition is used. Otherwise one should
+# specify the list of include paths that are normally passed to the compiler
+# using the -I flag.
+
+STRIP_FROM_INC_PATH =
+
+# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
+# less readable) file names. This can be useful if your file system doesn't
+# support long names like on DOS, Mac, or CD-ROM.
+# The default value is: NO.
+
+SHORT_NAMES = NO
+
+# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the
+# first line (until the first dot) of a Javadoc-style comment as the brief
+# description. If set to NO, the Javadoc-style will behave just like regular Qt-
+# style comments (thus requiring an explicit @brief command for a brief
+# description.)
+# The default value is: NO.
+
+JAVADOC_AUTOBRIEF = NO
+
+# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line
+# such as
+# /***************
+# as being the beginning of a Javadoc-style comment "banner". If set to NO, the
+# Javadoc-style will behave just like regular comments and it will not be
+# interpreted by doxygen.
+# The default value is: NO.
+
+JAVADOC_BANNER = NO
+
+# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first
+# line (until the first dot) of a Qt-style comment as the brief description. If
+# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
+# requiring an explicit \brief command for a brief description.)
+# The default value is: NO.
+
+QT_AUTOBRIEF = NO
+
+# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a
+# multi-line C++ special comment block (i.e. a block of //! or /// comments) as
+# a brief description. This used to be the default behavior. The new default is
+# to treat a multi-line C++ comment block as a detailed description. Set this
+# tag to YES if you prefer the old behavior instead.
+#
+# Note that setting this tag to YES also means that Rational Rose comments are
+# not recognized any more.
+# The default value is: NO.
+
+MULTILINE_CPP_IS_BRIEF = NO
+
+# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
+# documentation from any documented member that it re-implements.
+# The default value is: YES.
+
+INHERIT_DOCS = YES
+
+# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new
+# page for each member. If set to NO, the documentation of a member will be part
+# of the file/class/namespace that contains it.
+# The default value is: NO.
+
+SEPARATE_MEMBER_PAGES = NO
+
+# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen
+# uses this value to replace tabs by spaces in code fragments.
+# Minimum value: 1, maximum value: 16, default value: 4.
+
+TAB_SIZE = 4
+
+# This tag can be used to specify a number of aliases that act as commands in
+# the documentation. An alias has the form:
+# name=value
+# For example adding
+# "sideeffect=@par Side Effects:\n"
+# will allow you to put the command \sideeffect (or @sideeffect) in the
+# documentation, which will result in a user-defined paragraph with heading
+# "Side Effects:". You can put \n's in the value part of an alias to insert
+# newlines (in the resulting output). You can put ^^ in the value part of an
+# alias to insert a newline as if a physical newline was in the original file.
+# When you need a literal { or } or , in the value part of an alias you have to
+# escape them by means of a backslash (\), this can lead to conflicts with the
+# commands \{ and \} for these it is advised to use the version @{ and @} or use
+# a double escape (\\{ and \\})
+
+ALIASES =
+
+# This tag can be used to specify a number of word-keyword mappings (TCL only).
+# A mapping has the form "name=value". For example adding "class=itcl::class"
+# will allow you to use the command class in the itcl::class meaning.
+
+TCL_SUBST =
+
+# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
+# only. Doxygen will then generate output that is more tailored for C. For
+# instance, some of the names that are used will be different. The list of all
+# members will be omitted, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_FOR_C = NO
+
+# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or
+# Python sources only. Doxygen will then generate output that is more tailored
+# for that language. For instance, namespaces will be presented as packages,
+# qualified scopes will look different, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_JAVA = NO
+
+# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran
+# sources. Doxygen will then generate output that is tailored for Fortran.
+# The default value is: NO.
+
+OPTIMIZE_FOR_FORTRAN = NO
+
+# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL
+# sources. Doxygen will then generate output that is tailored for VHDL.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_VHDL = NO
+
+# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice
+# sources only. Doxygen will then generate output that is more tailored for that
+# language. For instance, namespaces will be presented as modules, types will be
+# separated into more groups, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_SLICE = NO
+
+# Doxygen selects the parser to use depending on the extension of the files it
+# parses. With this tag you can assign which parser to use for a given
+# extension. Doxygen has a built-in mapping, but you can override or extend it
+# using this tag. The format is ext=language, where ext is a file extension, and
+# language is one of the parsers supported by doxygen: IDL, Java, JavaScript,
+# Csharp (C#), C, C++, D, PHP, md (Markdown), Objective-C, Python, Slice,
+# Fortran (fixed format Fortran: FortranFixed, free formatted Fortran:
+# FortranFree, unknown formatted Fortran: Fortran. In the latter case the parser
+# tries to guess whether the code is fixed or free formatted code, this is the
+# default for Fortran type files), VHDL, tcl. For instance to make doxygen treat
+# .inc files as Fortran files (default is PHP), and .f files as C (default is
+# Fortran), use: inc=Fortran f=C.
+#
+# Note: For files without extension you can use no_extension as a placeholder.
+#
+# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
+# the files are not read by doxygen.
+
+EXTENSION_MAPPING =
+
+# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
+# according to the Markdown format, which allows for more readable
+# documentation. See https://daringfireball.net/projects/markdown/ for details.
+# The output of markdown processing is further processed by doxygen, so you can
+# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
+# case of backward compatibility issues.
+# The default value is: YES.
+
+MARKDOWN_SUPPORT = YES
+
+# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
+# to that level are automatically included in the table of contents, even if
+# they do not have an id attribute.
+# Note: This feature currently applies only to Markdown headings.
+# Minimum value: 0, maximum value: 99, default value: 5.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+TOC_INCLUDE_HEADINGS = 5
+
+# When enabled doxygen tries to link words that correspond to documented
+# classes, or namespaces to their corresponding documentation. Such a link can
+# be prevented in individual cases by putting a % sign in front of the word or
+# globally by setting AUTOLINK_SUPPORT to NO.
+# The default value is: YES.
+
+AUTOLINK_SUPPORT = YES
+
+# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
+# to include (a tag file for) the STL sources as input, then you should set this
+# tag to YES in order to let doxygen match functions declarations and
+# definitions whose arguments contain STL classes (e.g. func(std::string);
+# versus func(std::string) {}). This also makes the inheritance and collaboration
+# diagrams that involve STL classes more complete and accurate.
+# The default value is: NO.
+
+BUILTIN_STL_SUPPORT = NO
+
+# If you use Microsoft's C++/CLI language, you should set this option to YES to
+# enable parsing support.
+# The default value is: NO.
+
+CPP_CLI_SUPPORT = NO
+
+# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
+# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen
+# will parse them like normal C++ but will assume all classes use public instead
+# of private inheritance when no explicit protection keyword is present.
+# The default value is: NO.
+
+SIP_SUPPORT = NO
+
+# For Microsoft's IDL there are propget and propput attributes to indicate
+# getter and setter methods for a property. Setting this option to YES will make
+# doxygen replace the get and set methods by a property in the documentation.
+# This will only work if the methods are indeed getting or setting a simple
+# type. If this is not the case, or you want to show the methods anyway, you
+# should set this option to NO.
+# The default value is: YES.
+
+IDL_PROPERTY_SUPPORT = YES
+
+# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
+# tag is set to YES then doxygen will reuse the documentation of the first
+# member in the group (if any) for the other members of the group. By default
+# all members of a group must be documented explicitly.
+# The default value is: NO.
+
+DISTRIBUTE_GROUP_DOC = NO
+
+# If one adds a struct or class to a group and this option is enabled, then also
+# any nested class or struct is added to the same group. By default this option
+# is disabled and one has to add nested compounds explicitly via \ingroup.
+# The default value is: NO.
+
+GROUP_NESTED_COMPOUNDS = NO
+
+# Set the SUBGROUPING tag to YES to allow class member groups of the same type
+# (for instance a group of public functions) to be put as a subgroup of that
+# type (e.g. under the Public Functions section). Set it to NO to prevent
+# subgrouping. Alternatively, this can be done per class using the
+# \nosubgrouping command.
+# The default value is: YES.
+
+SUBGROUPING = YES
+
+# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
+# are shown inside the group in which they are included (e.g. using \ingroup)
+# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
+# and RTF).
+#
+# Note that this feature does not work in combination with
+# SEPARATE_MEMBER_PAGES.
+# The default value is: NO.
+
+INLINE_GROUPED_CLASSES = NO
+
+# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
+# with only public data fields or simple typedef fields will be shown inline in
+# the documentation of the scope in which they are defined (i.e. file,
+# namespace, or group documentation), provided this scope is documented. If set
+# to NO, structs, classes, and unions are shown on a separate page (for HTML and
+# Man pages) or section (for LaTeX and RTF).
+# The default value is: NO.
+
+INLINE_SIMPLE_STRUCTS = NO
+
+# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or
+# enum is documented as struct, union, or enum with the name of the typedef. So
+# typedef struct TypeS {} TypeT, will appear in the documentation as a struct
+# with name TypeT. When disabled the typedef will appear as a member of a file,
+# namespace, or class. And the struct will be named TypeS. This can typically be
+# useful for C code in case the coding convention dictates that all compound
+# types are typedef'ed and only the typedef is referenced, never the tag name.
+# The default value is: NO.
+
+TYPEDEF_HIDES_STRUCT = NO
+
+# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This
+# cache is used to resolve symbols given their name and scope. Since this can be
+# an expensive process and often the same symbol appears multiple times in the
+# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small
+# doxygen will become slower. If the cache is too large, memory is wasted. The
+# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range
+# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536
+# symbols. At the end of a run doxygen will report the cache usage and suggest
+# the optimal cache size from a speed point of view.
+# Minimum value: 0, maximum value: 9, default value: 0.
+
+LOOKUP_CACHE_SIZE = 0
+
+#---------------------------------------------------------------------------
+# Build related configuration options
+#---------------------------------------------------------------------------
+
+# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in
+# documentation are documented, even if no documentation was available. Private
+# class members and static file members will be hidden unless the
+# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES.
+# Note: This will also disable the warnings about undocumented members that are
+# normally produced when WARNINGS is set to YES.
+# The default value is: NO.
+
+EXTRACT_ALL = YES
+
+# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will
+# be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIVATE = YES
+
+# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual
+# methods of a class will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIV_VIRTUAL = YES
+
+# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal
+# scope will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PACKAGE = YES
+
+# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be
+# included in the documentation.
+# The default value is: NO.
+
+EXTRACT_STATIC = YES
+
+# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined
+# locally in source files will be included in the documentation. If set to NO,
+# only classes defined in header files are included. Does not have any effect
+# for Java sources.
+# The default value is: YES.
+
+EXTRACT_LOCAL_CLASSES = YES
+
+# This flag is only useful for Objective-C code. If set to YES, local methods,
+# which are defined in the implementation section but not in the interface are
+# included in the documentation. If set to NO, only methods in the interface are
+# included.
+# The default value is: NO.
+
+EXTRACT_LOCAL_METHODS = YES
+
+# If this flag is set to YES, the members of anonymous namespaces will be
+# extracted and appear in the documentation as a namespace called
+# 'anonymous_namespace{file}', where file will be replaced with the base name of
+# the file that contains the anonymous namespace. By default anonymous namespace
+# are hidden.
+# The default value is: NO.
+
+EXTRACT_ANON_NSPACES = NO
+
+# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
+# undocumented members inside documented classes or files. If set to NO these
+# members will be included in the various overviews, but no documentation
+# section is generated. This option has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_MEMBERS = NO
+
+# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
+# undocumented classes that are normally visible in the class hierarchy. If set
+# to NO, these classes will be included in the various overviews. This option
+# has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_CLASSES = NO
+
+# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
+# declarations. If set to NO, these declarations will be included in the
+# documentation.
+# The default value is: NO.
+
+HIDE_FRIEND_COMPOUNDS = NO
+
+# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
+# documentation blocks found inside the body of a function. If set to NO, these
+# blocks will be appended to the function's detailed documentation block.
+# The default value is: NO.
+
+HIDE_IN_BODY_DOCS = NO
+
+# The INTERNAL_DOCS tag determines if documentation that is typed after a
+# \internal command is included. If the tag is set to NO then the documentation
+# will be excluded. Set it to YES to include the internal documentation.
+# The default value is: NO.
+
+INTERNAL_DOCS = NO
+
+# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
+# names in lower-case letters. If set to YES, upper-case letters are also
+# allowed. This is useful if you have classes or files whose names only differ
+# in case and if your file system supports case sensitive file names. Windows
+# (including Cygwin) and Mac users are advised to set this option to NO.
+# The default value is: system dependent.
+
+CASE_SENSE_NAMES = YES
+
+# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
+# their full class and namespace scopes in the documentation. If set to YES, the
+# scope will be hidden.
+# The default value is: NO.
+
+HIDE_SCOPE_NAMES = NO
+
+# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will
+# append additional text to a page's title, such as Class Reference. If set to
+# YES the compound reference will be hidden.
+# The default value is: NO.
+
+HIDE_COMPOUND_REFERENCE= NO
+
+# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
+# the files that are included by a file in the documentation of that file.
+# The default value is: YES.
+
+SHOW_INCLUDE_FILES = YES
+
+# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each
+# grouped member an include statement to the documentation, telling the reader
+# which file to include in order to use the member.
+# The default value is: NO.
+
+SHOW_GROUPED_MEMB_INC = NO
+
+# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
+# files with double quotes in the documentation rather than with sharp brackets.
+# The default value is: NO.
+
+FORCE_LOCAL_INCLUDES = NO
+
+# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
+# documentation for inline members.
+# The default value is: YES.
+
+INLINE_INFO = YES
+
+# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the
+# (detailed) documentation of file and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order.
+# The default value is: YES.
+
+SORT_MEMBER_DOCS = YES
+
+# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief
+# descriptions of file, namespace and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order. Note that
+# this will also influence the order of the classes in the class list.
+# The default value is: NO.
+
+SORT_BRIEF_DOCS = NO
+
+# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the
+# (brief and detailed) documentation of class members so that constructors and
+# destructors are listed first. If set to NO the constructors will appear in the
+# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS.
+# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief
+# member documentation.
+# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting
+# detailed member documentation.
+# The default value is: NO.
+
+SORT_MEMBERS_CTORS_1ST = NO
+
+# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy
+# of group names into alphabetical order. If set to NO the group names will
+# appear in their defined order.
+# The default value is: NO.
+
+SORT_GROUP_NAMES = NO
+
+# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by
+# fully-qualified names, including namespaces. If set to NO, the class list will
+# be sorted only by class name, not including the namespace part.
+# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES.
+# Note: This option applies only to the class list, not to the alphabetical
+# list.
+# The default value is: NO.
+
+SORT_BY_SCOPE_NAME = NO
+
+# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper
+# type resolution of all parameters of a function it will reject a match between
+# the prototype and the implementation of a member function even if there is
+# only one candidate or it is obvious which candidate to choose by doing a
+# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still
+# accept a match between prototype and implementation in such cases.
+# The default value is: NO.
+
+STRICT_PROTO_MATCHING = NO
+
+# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo
+# list. This list is created by putting \todo commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TODOLIST = YES
+
+# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
+# list. This list is created by putting \test commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TESTLIST = YES
+
+# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
+# list. This list is created by putting \bug commands in the documentation.
+# The default value is: YES.
+
+GENERATE_BUGLIST = YES
+
+# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
+# the deprecated list. This list is created by putting \deprecated commands in
+# the documentation.
+# The default value is: YES.
+
+GENERATE_DEPRECATEDLIST= YES
+
+# The ENABLED_SECTIONS tag can be used to enable conditional documentation
+# sections, marked by \if <section_label> ... \endif and \cond <section_label>
+# ... \endcond blocks.
+
+ENABLED_SECTIONS =
+
+# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
+# initial value of a variable or macro / define can have for it to appear in the
+# documentation. If the initializer consists of more lines than specified here
+# it will be hidden. Use a value of 0 to hide initializers completely. The
+# appearance of the value of individual variables and macros / defines can be
+# controlled using \showinitializer or \hideinitializer command in the
+# documentation regardless of this setting.
+# Minimum value: 0, maximum value: 10000, default value: 30.
+
+MAX_INITIALIZER_LINES = 30
+
+# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
+# the bottom of the documentation of classes and structs. If set to YES, the
+# list will mention the files that were used to generate the documentation.
+# The default value is: YES.
+
+SHOW_USED_FILES = YES
+
+# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
+# will remove the Files entry from the Quick Index and from the Folder Tree View
+# (if specified).
+# The default value is: YES.
+
+SHOW_FILES = YES
+
+# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
+# page. This will remove the Namespaces entry from the Quick Index and from the
+# Folder Tree View (if specified).
+# The default value is: YES.
+
+SHOW_NAMESPACES = YES
+
+# The FILE_VERSION_FILTER tag can be used to specify a program or script that
+# doxygen should invoke to get the current version for each file (typically from
+# the version control system). Doxygen will invoke the program by executing (via
+# popen()) the command command input-file, where command is the value of the
+# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided
+# by doxygen. Whatever the program writes to standard output is used as the file
+# version. For an example see the documentation.
+
+FILE_VERSION_FILTER =
+
+# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
+# by doxygen. The layout file controls the global structure of the generated
+# output files in an output format independent way. To create the layout file
+# that represents doxygen's defaults, run doxygen with the -l option. You can
+# optionally specify a file name after the option, if omitted DoxygenLayout.xml
+# will be used as the name of the layout file.
+#
+# Note that if you run doxygen from a directory containing a file called
+# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE
+# tag is left empty.
+
+LAYOUT_FILE =
+
+# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
+# the reference definitions. This must be a list of .bib files. The .bib
+# extension is automatically appended if omitted. This requires the bibtex tool
+# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info.
+# For LaTeX the style of the bibliography can be controlled using
+# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
+# search path. See also \cite for info how to create references.
+
+CITE_BIB_FILES =
+
+#---------------------------------------------------------------------------
+# Configuration options related to warning and progress messages
+#---------------------------------------------------------------------------
+
+# The QUIET tag can be used to turn on/off the messages that are generated to
+# standard output by doxygen. If QUIET is set to YES this implies that the
+# messages are off.
+# The default value is: NO.
+
+QUIET = NO
+
+# The WARNINGS tag can be used to turn on/off the warning messages that are
+# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES
+# this implies that the warnings are on.
+#
+# Tip: Turn warnings on while writing the documentation.
+# The default value is: YES.
+
+WARNINGS = YES
+
+# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate
+# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
+# will automatically be disabled.
+# The default value is: YES.
+
+WARN_IF_UNDOCUMENTED = YES
+
+# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for
+# potential errors in the documentation, such as not documenting some parameters
+# in a documented function, or documenting parameters that don't exist or using
+# markup commands wrongly.
+# The default value is: YES.
+
+WARN_IF_DOC_ERROR = YES
+
+# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
+# are documented, but have no documentation for their parameters or return
+# value. If set to NO, doxygen will only warn about wrong or incomplete
+# parameter documentation, but not about the absence of documentation. If
+# EXTRACT_ALL is set to YES then this flag will automatically be disabled.
+# The default value is: NO.
+
+WARN_NO_PARAMDOC = NO
+
+# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when
+# a warning is encountered.
+# The default value is: NO.
+
+WARN_AS_ERROR = NO
+
+# The WARN_FORMAT tag determines the format of the warning messages that doxygen
+# can produce. The string should contain the $file, $line, and $text tags, which
+# will be replaced by the file and line number from which the warning originated
+# and the warning text. Optionally the format may contain $version, which will
+# be replaced by the version of the file (if it could be obtained via
+# FILE_VERSION_FILTER)
+# The default value is: $file:$line: $text.
+
+WARN_FORMAT = "$file:$line: $text"
+
+# The WARN_LOGFILE tag can be used to specify a file to which warning and error
+# messages should be written. If left blank the output is written to standard
+# error (stderr).
+
+WARN_LOGFILE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the input files
+#---------------------------------------------------------------------------
+
+# The INPUT tag is used to specify the files and/or directories that contain
+# documented source files. You may enter file names like myfile.cpp or
+# directories like /usr/src/myproject. Separate the files or directories with
+# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
+# Note: If this tag is empty the current directory is searched.
+
+INPUT =
+
+# This tag can be used to specify the character encoding of the source files
+# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
+# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
+# documentation (see: https://www.gnu.org/software/libiconv/) for the list of
+# possible encodings.
+# The default value is: UTF-8.
+
+INPUT_ENCODING = UTF-8
+
+# If the value of the INPUT tag contains directories, you can use the
+# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
+# *.h) to filter out the source-files in the directories.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# read by doxygen.
+#
+# If left blank the following patterns are tested: *.c, *.cc, *.cxx, *.cpp,
+# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h,
+# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc,
+# *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C comment),
+# *.doc (to be provided as doxygen C comment), *.txt (to be provided as doxygen
+# C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f, *.for, *.tcl, *.vhd,
+# *.vhdl, *.ucf, *.qsf and *.ice.
+
+FILE_PATTERNS = *.c \
+ *.cc \
+ *.cxx \
+ *.cpp \
+ *.c++ \
+ *.java \
+ *.ii \
+ *.ixx \
+ *.ipp \
+ *.i++ \
+ *.inl \
+ *.idl \
+ *.ddl \
+ *.odl \
+ *.h \
+ *.hh \
+ *.hxx \
+ *.hpp \
+ *.h++ \
+ *.cs \
+ *.d \
+ *.php \
+ *.php4 \
+ *.php5 \
+ *.phtml \
+ *.inc \
+ *.m \
+ *.markdown \
+ *.md \
+ *.mm \
+ *.dox \
+ *.doc \
+ *.txt \
+ *.py \
+ *.pyw \
+ *.f90 \
+ *.f95 \
+ *.f03 \
+ *.f08 \
+ *.f \
+ *.for \
+ *.tcl \
+ *.vhd \
+ *.vhdl \
+ *.ucf \
+ *.qsf \
+ *.ice
+
+# The RECURSIVE tag can be used to specify whether or not subdirectories should
+# be searched for input files as well.
+# The default value is: NO.
+
+RECURSIVE = YES
+
+# The EXCLUDE tag can be used to specify files and/or directories that should be
+# excluded from the INPUT source files. This way you can easily exclude a
+# subdirectory from a directory tree whose root is specified with the INPUT tag.
+#
+# Note that relative paths are relative to the directory from which doxygen is
+# run.
+
+EXCLUDE =
+
+# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
+# directories that are symbolic links (a Unix file system feature) are excluded
+# from the input.
+# The default value is: NO.
+
+EXCLUDE_SYMLINKS = NO
+
+# If the value of the INPUT tag contains directories, you can use the
+# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
+# certain files from those directories.
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories for example use the pattern */test/*
+
+EXCLUDE_PATTERNS =
+
+# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
+# (namespaces, classes, functions, etc.) that should be excluded from the
+# output. The symbol name can be a fully qualified name, a word, or if the
+# wildcard * is used, a substring. Examples: ANamespace, AClass,
+# AClass::ANamespace, ANamespace::*Test
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories use the pattern */test/*
+
+EXCLUDE_SYMBOLS =
+
+# The EXAMPLE_PATH tag can be used to specify one or more files or directories
+# that contain example code fragments that are included (see the \include
+# command).
+
+EXAMPLE_PATH =
+
+# If the value of the EXAMPLE_PATH tag contains directories, you can use the
+# EXAMPLE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
+# *.h) to filter out the source-files in the directories. If left blank all
+# files are included.
+
+EXAMPLE_PATTERNS = *
+
+# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
+# searched for input files to be used with the \include or \dontinclude commands
+# irrespective of the value of the RECURSIVE tag.
+# The default value is: NO.
+
+EXAMPLE_RECURSIVE = NO
+
+# The IMAGE_PATH tag can be used to specify one or more files or directories
+# that contain images that are to be included in the documentation (see the
+# \image command).
+
+IMAGE_PATH =
+
+# The INPUT_FILTER tag can be used to specify a program that doxygen should
+# invoke to filter for each input file. Doxygen will invoke the filter program
+# by executing (via popen()) the command:
+#
+# <filter> <input-file>
+#
+# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
+# name of an input file. Doxygen will then use the output that the filter
+# program writes to standard output. If FILTER_PATTERNS is specified, this tag
+# will be ignored.
+#
+# Note that the filter must not add or remove lines; it is applied before the
+# code is scanned, but not when the output code is generated. If lines are added
+# or removed, the anchors will not be placed correctly.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+INPUT_FILTER =
+
+# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
+# basis. Doxygen will compare the file name with each pattern and apply the
+# filter if there is a match. The filters are a list of the form: pattern=filter
+# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
+# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
+# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+FILTER_PATTERNS =
+
+# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
+# INPUT_FILTER) will also be used to filter the input files that are used for
+# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
+# The default value is: NO.
+
+FILTER_SOURCE_FILES = NO
+
+# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
+# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
+# it is also possible to disable source filtering for a specific pattern using
+# *.ext= (so without naming a filter).
+# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
+
+FILTER_SOURCE_PATTERNS =
+
+# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
+# is part of the input, its contents will be placed on the main page
+# (index.html). This can be useful if you have a project on for instance GitHub
+# and want to reuse the introduction page also for the doxygen output.
+
+USE_MDFILE_AS_MAINPAGE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to source browsing
+#---------------------------------------------------------------------------
+
+# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
+# generated. Documented entities will be cross-referenced with these sources.
+#
+# Note: To get rid of all source code in the generated output, make sure that
+# also VERBATIM_HEADERS is set to NO.
+# The default value is: NO.
+
+SOURCE_BROWSER = NO
+
+# Setting the INLINE_SOURCES tag to YES will include the body of functions,
+# classes and enums directly into the documentation.
+# The default value is: NO.
+
+INLINE_SOURCES = NO
+
+# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
+# special comment blocks from generated source code fragments. Normal C, C++ and
+# Fortran comments will always remain visible.
+# The default value is: YES.
+
+STRIP_CODE_COMMENTS = YES
+
+# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
+# entity all documented functions referencing it will be listed.
+# The default value is: NO.
+
+REFERENCED_BY_RELATION = NO
+
+# If the REFERENCES_RELATION tag is set to YES then for each documented function
+# all documented entities called/used by that function will be listed.
+# The default value is: NO.
+
+REFERENCES_RELATION = NO
+
+# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
+# to YES then the hyperlinks from functions in REFERENCES_RELATION and
+# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
+# link to the documentation.
+# The default value is: YES.
+
+REFERENCES_LINK_SOURCE = YES
+
+# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
+# source code will show a tooltip with additional information such as prototype,
+# brief description and links to the definition and documentation. Since this
+# will make the HTML file larger and loading of large files a bit slower, you
+# can opt to disable this feature.
+# The default value is: YES.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+SOURCE_TOOLTIPS = YES
+
+# If the USE_HTAGS tag is set to YES then the references to source code will
+# point to the HTML generated by the htags(1) tool instead of doxygen built-in
+# source browser. The htags tool is part of GNU's global source tagging system
+# (see https://www.gnu.org/software/global/global.html). You will need version
+# 4.8.6 or higher.
+#
+# To use it do the following:
+# - Install the latest version of global
+# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file
+# - Make sure the INPUT points to the root of the source tree
+# - Run doxygen as normal
+#
+# Doxygen will invoke htags (and that will in turn invoke gtags), so these
+# tools must be available from the command line (i.e. in the search path).
+#
+# The result: instead of the source browser generated by doxygen, the links to
+# source code will now point to the output of htags.
+# The default value is: NO.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS = NO
+
+# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS = YES
+
+# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the
+# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the
+# cost of reduced performance. This can be particularly helpful with template
+# rich C++ code for which doxygen's built-in parser lacks the necessary type
+# information.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+# The default value is: NO.
+
+CLANG_ASSISTED_PARSING = NO
+
+# If clang assisted parsing is enabled you can provide the compiler with command
+# line options that you would normally use when invoking the compiler. Note that
+# the include paths will already be set by doxygen for the files and directories
+# specified with INPUT and INCLUDE_PATH.
+# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES.
+
+CLANG_OPTIONS =
+
+# If clang assisted parsing is enabled you can provide the clang parser with the
+# path to the compilation database (see:
+# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) used when the files
+# were built. This is equivalent to specifying the "-p" option to a clang tool,
+# such as clang-check. These options will then be passed to the parser.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+
+CLANG_DATABASE_PATH =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+
+# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
+# compounds will be generated. Enable this if the project contains a lot of
+# classes, structs, unions or interfaces.
+# The default value is: YES.
+
+ALPHABETICAL_INDEX = YES
+
+# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
+# which the alphabetical index list will be split.
+# Minimum value: 1, maximum value: 20, default value: 5.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+COLS_IN_ALPHA_INDEX = 5
+
+# In case all classes in a project start with a common prefix, all classes will
+# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
+# can be used to specify a prefix (or a list of prefixes) that should be ignored
+# while generating the index headers.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+IGNORE_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the HTML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output
+# The default value is: YES.
+
+GENERATE_HTML = YES
+
+# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_OUTPUT = html
+
+# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
+# generated HTML page (for example: .htm, .php, .asp).
+# The default value is: .html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FILE_EXTENSION = .html
+
+# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
+# each generated HTML page. If the tag is left blank doxygen will generate a
+# standard header.
+#
+# To get valid HTML, the header file must include any scripts and style sheets
+# that doxygen needs, which depend on the configuration options used (e.g.
+# the setting GENERATE_TREEVIEW). It is highly recommended to start with a
+# default header using
+# doxygen -w html new_header.html new_footer.html new_stylesheet.css
+# YourConfigFile
+# and then modify the file new_header.html. See also section "Doxygen usage"
+# for information on how to generate the default header that doxygen normally
+# uses.
+# Note: The header is subject to change so you typically have to regenerate the
+# default header when upgrading to a newer version of doxygen. For a description
+# of the possible markers and block names see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_HEADER =
+
+# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
+# generated HTML page. If the tag is left blank doxygen will generate a standard
+# footer. See HTML_HEADER for more information on how to generate a default
+# footer and what special commands can be used inside the footer. See also
+# section "Doxygen usage" for information on how to generate the default footer
+# that doxygen normally uses.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FOOTER =
+
+# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
+# sheet that is used by each HTML page. It can be used to fine-tune the look of
+# the HTML output. If left blank doxygen will generate a default style sheet.
+# See also section "Doxygen usage" for information on how to generate the style
+# sheet that doxygen normally uses.
+# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
+# it is more robust and this tag (HTML_STYLESHEET) will in the future become
+# obsolete.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_STYLESHEET =
+
+# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# cascading style sheets that are included after the standard style sheets
+# created by doxygen. Using this option one can overrule certain style aspects.
+# This is preferred over using HTML_STYLESHEET since it does not replace the
+# standard style sheet and is therefore more robust against future updates.
+# Doxygen will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list). For an example see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_STYLESHEET =
+
+# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the HTML output directory. Note
+# that these files will be copied to the base HTML output directory. Use the
+# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these
+# files. In the HTML_STYLESHEET file, use the file name only. Also note that the
+# files will be copied as-is; there are no commands or markers available.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_FILES =
+
+# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
+# will adjust the colors in the style sheet and background images according to
+# this color. Hue is specified as an angle on a colorwheel, see
+# https://en.wikipedia.org/wiki/Hue for more information. For instance the value
+# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
+# purple, and 360 is red again.
+# Minimum value: 0, maximum value: 359, default value: 220.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_HUE = 220
+
+# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
+# in the HTML output. For a value of 0 the output will use grayscales only. A
+# value of 255 will produce the most vivid colors.
+# Minimum value: 0, maximum value: 255, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_SAT = 100
+
+# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
+# luminance component of the colors in the HTML output. Values below 100
+# gradually make the output lighter, whereas values above 100 make the output
+# darker. The value divided by 100 is the actual gamma applied, so 80 represents
+# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not
+# change the gamma.
+# Minimum value: 40, maximum value: 240, default value: 80.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_GAMMA = 80
+
+# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
+# page will contain the date and time when the page was generated. Setting this
+# to YES can help to show when doxygen was last run and thus if the
+# documentation is up to date.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_TIMESTAMP = NO
+
+# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
+# documentation will contain a main index with vertical navigation menus that
+# are dynamically created via JavaScript. If disabled, the navigation index will
+# consist of multiple levels of tabs that are statically embedded in every HTML
+# page. Disable this option to support browsers that do not have JavaScript,
+# like the Qt help browser.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_MENUS = YES
+
+# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
+# documentation will contain sections that can be hidden and shown after the
+# page has loaded.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_SECTIONS = NO
+
+# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
+# shown in the various tree structured indices initially; the user can expand
+# and collapse entries dynamically later on. Doxygen will expand the tree to
+# such a level that at most the specified number of entries are visible (unless
+# a fully collapsed tree already exceeds this amount). So setting the number of
+# entries to 1 will produce a fully collapsed tree by default. 0 is a special
+# value representing an infinite number of entries and will result in a fully
+# expanded tree by default.
+# Minimum value: 0, maximum value: 9999, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_INDEX_NUM_ENTRIES = 100
+
+# If the GENERATE_DOCSET tag is set to YES, additional index files will be
+# generated that can be used as input for Apple's Xcode 3 integrated development
+# environment (see: https://developer.apple.com/xcode/), introduced with OSX
+# 10.5 (Leopard). To create a documentation set, doxygen will generate a
+# Makefile in the HTML output directory. Running make will produce the docset in
+# that directory and running make install will install the docset in
+# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
+# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy
+# genXcode/_index.html for more information.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_DOCSET = NO
+
+# This tag determines the name of the docset feed. A documentation feed provides
+# an umbrella under which multiple documentation sets from a single provider
+# (such as a company or product suite) can be grouped.
+# The default value is: Doxygen generated docs.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_FEEDNAME = "Doxygen generated docs"
+
+# This tag specifies a string that should uniquely identify the documentation
+# set bundle. This should be a reverse domain-name style string, e.g.
+# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_BUNDLE_ID = org.doxygen.Project
+
+# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
+# the documentation publisher. This should be a reverse domain-name style
+# string, e.g. com.mycompany.MyDocSet.documentation.
+# The default value is: org.doxygen.Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_ID = org.doxygen.Publisher
+
+# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
+# The default value is: Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_NAME = Publisher
+
+# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
+# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
+# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
+# (see: https://www.microsoft.com/en-us/download/details.aspx?id=21138) on
+# Windows.
+#
+# The HTML Help Workshop contains a compiler that can convert all HTML output
+# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
+# files are now used as the Windows 98 help format, and will replace the old
+# Windows help format (.hlp) on all Windows platforms in the future. Compressed
+# HTML files also contain an index, a table of contents, and you can search for
+# words in the documentation. The HTML workshop also contains a viewer for
+# compressed HTML files.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_HTMLHELP = NO
+
+# The CHM_FILE tag can be used to specify the file name of the resulting .chm
+# file. You can add a path in front of the file if the result should not be
+# written to the html output directory.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_FILE =
+
+# The HHC_LOCATION tag can be used to specify the location (absolute path
+# including file name) of the HTML help compiler (hhc.exe). If non-empty,
+# doxygen will try to run the HTML help compiler on the generated index.hhp.
+# The file has to be specified with full path.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+HHC_LOCATION =
+
+# The GENERATE_CHI flag controls if a separate .chi index file is generated
+# (YES) or that it should be included in the master .chm file (NO).
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+GENERATE_CHI = NO
+
+# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
+# and project file content.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_INDEX_ENCODING =
+
+# The BINARY_TOC flag controls whether a binary table of contents is generated
+# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it
+# enables the Previous and Next buttons.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+BINARY_TOC = NO
+
+# The TOC_EXPAND flag can be set to YES to add extra items for group members to
+# the table of contents of the HTML help documentation and to the tree view.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+TOC_EXPAND = NO
+
+# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
+# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
+# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
+# (.qch) of the generated HTML documentation.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_QHP = NO
+
+# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
+# the file name of the resulting .qch file. The path specified is relative to
+# the HTML output folder.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QCH_FILE =
+
+# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
+# Project output. For more information please see Qt Help Project / Namespace
+# (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_NAMESPACE = org.doxygen.Project
+
+# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
+# Help Project output. For more information please see Qt Help Project / Virtual
+# Folders (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-
+# folders).
+# The default value is: doc.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_VIRTUAL_FOLDER = doc
+
+# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
+# filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_NAME =
+
+# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
+# custom filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_ATTRS =
+
+# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
+# project's filter section matches. Qt Help Project / Filter Attributes (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_SECT_FILTER_ATTRS =
+
+# The QHG_LOCATION tag can be used to specify the location of Qt's
+# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
+# generated .qhp file.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHG_LOCATION =
+
+# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
+# generated, together with the HTML files, they form an Eclipse help plugin. To
+# install this plugin and make it available under the help contents menu in
+# Eclipse, the contents of the directory containing the HTML and XML files needs
+# to be copied into the plugins directory of eclipse. The name of the directory
+# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
+# After copying Eclipse needs to be restarted before the help appears.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_ECLIPSEHELP = NO
+
+# A unique identifier for the Eclipse help plugin. When installing the plugin
+# the directory name containing the HTML and XML files should also have this
+# name. Each documentation set should have its own identifier.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
+
+ECLIPSE_DOC_ID = org.doxygen.Project
+
+# If you want full control over the layout of the generated HTML pages it might
+# be necessary to disable the index and replace it with your own. The
+# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
+# of each HTML page. A value of NO enables the index and the value YES disables
+# it. Since the tabs in the index contain the same information as the navigation
+# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+DISABLE_INDEX = NO
+
+# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
+# structure should be generated to display hierarchical information. If the tag
+# value is set to YES, a side panel will be generated containing a tree-like
+# index structure (just like the one that is generated for HTML Help). For this
+# to work a browser that supports JavaScript, DHTML, CSS and frames is required
+# (i.e. any modern browser). Windows users are probably better off using the
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
+# further fine-tune the look of the index. As an example, the default style
+# sheet generated by doxygen has an example that shows how to put an image at
+# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
+# the same information as the tab index, you could consider setting
+# DISABLE_INDEX to YES when enabling this option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_TREEVIEW = NO
+
+# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
+# doxygen will group on one line in the generated HTML documentation.
+#
+# Note that a value of 0 will completely suppress the enum values from appearing
+# in the overview section.
+# Minimum value: 0, maximum value: 20, default value: 4.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+ENUM_VALUES_PER_LINE = 4
+
+# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
+# to set the initial width (in pixels) of the frame in which the tree is shown.
+# Minimum value: 0, maximum value: 1500, default value: 250.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+TREEVIEW_WIDTH = 250
+
+# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
+# external symbols imported via tag files in a separate window.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+EXT_LINKS_IN_WINDOW = NO
+
+# Use this tag to change the font size of LaTeX formulas included as images in
+# the HTML documentation. When you change the font size after a successful
+# doxygen run you need to manually remove any form_*.png images from the HTML
+# output directory to force them to be regenerated.
+# Minimum value: 8, maximum value: 50, default value: 10.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_FONTSIZE = 10
+
+# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
+# generated for formulas are transparent PNGs. Transparent PNGs are not
+# supported properly for IE 6.0, but are supported on all modern browsers.
+#
+# Note that when changing this option you need to delete any form_*.png files in
+# the HTML output directory before the changes have effect.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_TRANSPARENT = YES
+
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want the formulas to look prettier in the HTML output. When
+# enabled you may also need to install MathJax separately and configure the path
+# to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX = YES
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. See the MathJax site (see:
+# http://docs.mathjax.org/en/latest/output.html) for more details.
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility), NativeMML (i.e. MathML) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment.
+# The default value is: https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH = https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering. For example
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow; in that case
+# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
+# search using the keyboard; to jump to the search box use <access key> + S
+# (what the <access key> is depends on the OS and browser, but it is typically
+# <CTRL>, <ALT>/<option>, or both). Inside the search box use the <cursor down
+# key> to jump into the search results window, the results can be navigated
+# using the <cursor keys>. Press <Enter> to select an item or <escape> to cancel
+# the search. The filter options can be selected when the cursor is inside the
+# search box by pressing <Shift>+<cursor down>. Also here use the <cursor keys>
+# to select a filter and <Enter> or <escape> to activate or cancel the filter
+# option.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+SEARCHENGINE = YES
+
+# When the SERVER_BASED_SEARCH tag is enabled the search engine will be
+# implemented using a web server instead of a web client using JavaScript. There
+# are two flavors of web server based searching depending on the EXTERNAL_SEARCH
+# setting. When disabled, doxygen will generate a PHP script for searching and
+# an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing
+# and searching needs to be provided by external tools. See the section
+# "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SERVER_BASED_SEARCH = NO
+
+# When EXTERNAL_SEARCH tag is enabled doxygen will no longer generate the PHP
+# script for searching. Instead the search results are written to an XML file
+# which needs to be processed by an external indexer. Doxygen will invoke an
+# external search engine pointed to by the SEARCHENGINE_URL option to obtain the
+# search results.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/).
+#
+# See the section "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH = NO
+
+# The SEARCHENGINE_URL should point to a search engine hosted by a web server
+# which will return the search results when EXTERNAL_SEARCH is enabled.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/). See the section "External Indexing and
+# Searching" for details.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHENGINE_URL =
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the unindexed
+# search data is written to a file for indexing by an external tool. With the
+# SEARCHDATA_FILE tag the name of this file can be specified.
+# The default file is: searchdata.xml.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHDATA_FILE = searchdata.xml
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the
+# EXTERNAL_SEARCH_ID tag can be used as an identifier for the project. This is
+# useful in combination with EXTRA_SEARCH_MAPPINGS to search through multiple
+# projects and redirect the results back to the right project.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH_ID =
+
+# The EXTRA_SEARCH_MAPPINGS tag can be used to enable searching through doxygen
+# projects other than the one defined by this configuration file, but that are
+# all added to the same external search index. Each project needs to have a
+# unique id set via EXTERNAL_SEARCH_ID. The search mapping then maps the id
+# to a relative location where the documentation can be found. The format is:
+# EXTRA_SEARCH_MAPPINGS = tagname1=loc1 tagname2=loc2 ...
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTRA_SEARCH_MAPPINGS =
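+# For example, two hypothetical projects sharing one external index (the ids and
+# locations below are placeholders) could be mapped like:
+# EXTRA_SEARCH_MAPPINGS = projA=../projA/html projB=../projB/html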
+
+#---------------------------------------------------------------------------
+# Configuration options related to the LaTeX output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_LATEX tag is set to YES, doxygen will generate LaTeX output.
+# The default value is: YES.
+
+GENERATE_LATEX = YES
+
+# The LATEX_OUTPUT tag is used to specify where the LaTeX docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: latex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_OUTPUT = latex
+
+# The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be
+# invoked.
+#
+# Note that when USE_PDFLATEX is disabled the default is latex, and when
+# USE_PDFLATEX is enabled the default is pdflatex; if latex is chosen in the
+# latter case it is overridden by pdflatex. For specific output languages the
+# default may have been set differently; this depends on the implementation of
+# the output language.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_CMD_NAME =
+
+# The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate
+# index for LaTeX.
+# Note: This tag is used in the Makefile / make.bat.
+# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file
+# (.tex).
+# The default file is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+MAKEINDEX_CMD_NAME = makeindex
+
+# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to
+# generate index for LaTeX. In case there is no backslash (\) as first character
+# it will be automatically added in the LaTeX code.
+# Note: This tag is used in the generated output file (.tex).
+# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat.
+# The default value is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_MAKEINDEX_CMD = makeindex
+
+# If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+COMPACT_LATEX = NO
+
+# The PAPER_TYPE tag can be used to set the paper type that is used by the
+# printer.
+# Possible values are: a4 (210 x 297 mm), letter (8.5 x 11 inches), legal (8.5 x
+# 14 inches) and executive (7.25 x 10.5 inches).
+# The default value is: a4.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PAPER_TYPE = a4
+
+# The EXTRA_PACKAGES tag can be used to specify one or more LaTeX package names
+# that should be included in the LaTeX output. The package can be specified just
+# by its name or with the correct syntax as to be used with the LaTeX
+# \usepackage command. To get the times font for instance you can specify :
+# EXTRA_PACKAGES=times or EXTRA_PACKAGES={times}
+# To use the option intlimits with the amsmath package you can specify:
+# EXTRA_PACKAGES=[intlimits]{amsmath}
+# If left blank no extra packages will be included.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+EXTRA_PACKAGES =
+
+# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
+# generated LaTeX document. The header should contain everything until the first
+# chapter. If it is left blank doxygen will generate a standard header. See
+# section "Doxygen usage" for information on how to let doxygen write the
+# default header to a separate file.
+#
+# Note: Only use a user-defined header if you know what you are doing! The
+# following commands have a special meaning inside the header: $title,
+# $datetime, $date, $doxygenversion, $projectname, $projectnumber,
+# $projectbrief, $projectlogo. Doxygen will replace $title with the empty
+# string, for the replacement values of the other commands the user is referred
+# to HTML_HEADER.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HEADER =
+
+# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the
+# generated LaTeX document. The footer should contain everything after the last
+# chapter. If it is left blank doxygen will generate a standard footer. See
+# LATEX_HEADER for more information on how to generate a default footer and what
+# special commands can be used inside the footer.
+#
+# Note: Only use a user-defined footer if you know what you are doing!
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_FOOTER =
+
+# The LATEX_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# LaTeX style sheets that are included after the standard style sheets created
+# by doxygen. Using this option one can overrule certain style aspects. Doxygen
+# will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list).
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_STYLESHEET =
+
+# The LATEX_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the LATEX_OUTPUT output
+# directory. Note that the files will be copied as-is; there are no commands or
+# markers available.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_FILES =
+
+# If the PDF_HYPERLINKS tag is set to YES, the LaTeX that is generated is
+# prepared for conversion to PDF (using ps2pdf or pdflatex). The PDF file will
+# contain links (just like the HTML output) instead of page references. This
+# makes the output suitable for online browsing using a PDF viewer.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PDF_HYPERLINKS = YES
+
+# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate
+# the PDF file directly from the LaTeX files. Set this option to YES, to get a
+# higher quality PDF documentation.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+USE_PDFLATEX = YES
+
+# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode
+# command to the generated LaTeX files. This will instruct LaTeX to keep running
+# if errors occur, instead of asking the user for help. This option is also used
+# when generating formulas in HTML.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BATCHMODE = NO
+
+# If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the
+# index chapters (such as File Index, Compound Index, etc.) in the output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HIDE_INDICES = NO
+
+# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source
+# code with syntax highlighting in the LaTeX output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_SOURCE_CODE = NO
+
+# The LATEX_BIB_STYLE tag can be used to specify the style to use for the
+# bibliography, e.g. plainnat, or ieeetr. See
+# https://en.wikipedia.org/wiki/BibTeX and \cite for more info.
+# The default value is: plain.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BIB_STYLE = plain
+
+# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated
+# page will contain the date and time when the page was generated. Setting this
+# to NO can help when comparing the output of multiple runs.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_TIMESTAMP = NO
+
+# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute)
+# path from which the emoji images will be read. If a relative path is entered,
+# it will be relative to the LATEX_OUTPUT directory. If left blank the
+# LATEX_OUTPUT directory will be used.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EMOJI_DIRECTORY =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the RTF output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_RTF tag is set to YES, doxygen will generate RTF output. The
+# RTF output is optimized for Word 97 and may not look too pretty with other RTF
+# readers/editors.
+# The default value is: NO.
+
+GENERATE_RTF = NO
+
+# The RTF_OUTPUT tag is used to specify where the RTF docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: rtf.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_OUTPUT = rtf
+
+# If the COMPACT_RTF tag is set to YES, doxygen generates more compact RTF
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+COMPACT_RTF = NO
+
+# If the RTF_HYPERLINKS tag is set to YES, the RTF that is generated will
+# contain hyperlink fields. The RTF file will contain links (just like the HTML
+# output) instead of page references. This makes the output suitable for online
+# browsing using Word or some other Word compatible readers that support those
+# fields.
+#
+# Note: WordPad (write) and others do not support links.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_HYPERLINKS = NO
+
+# Load stylesheet definitions from file. Syntax is similar to doxygen's
+# configuration file, i.e. a series of assignments. You only have to provide
+# replacements; missing definitions are set to their default value.
+#
+# See also section "Doxygen usage" for information on how to generate the
+# default style sheet that doxygen normally uses.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_STYLESHEET_FILE =
+
+# Set optional variables used in the generation of an RTF document. Syntax is
+# similar to doxygen's configuration file. A template extensions file can be
+# generated using doxygen -e rtf extensionFile.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_EXTENSIONS_FILE =
+
+# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
+# with syntax highlighting in the RTF output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_SOURCE_CODE = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the man page output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_MAN tag is set to YES, doxygen will generate man pages for
+# classes and files.
+# The default value is: NO.
+
+GENERATE_MAN = NO
+
+# The MAN_OUTPUT tag is used to specify where the man pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it. A directory man3 will be created inside the directory specified by
+# MAN_OUTPUT.
+# The default directory is: man.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_OUTPUT = man
+
+# The MAN_EXTENSION tag determines the extension that is added to the generated
+# man pages. In case the manual section does not start with a number, the number
+# 3 is prepended. The dot (.) at the beginning of the MAN_EXTENSION tag is
+# optional.
+# The default value is: .3.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_EXTENSION = .3
+
+# The MAN_SUBDIR tag determines the name of the directory created within
+# MAN_OUTPUT in which the man pages are placed. It defaults to man followed by
+# MAN_EXTENSION with the initial . removed.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_SUBDIR =
+
+# If the MAN_LINKS tag is set to YES and doxygen generates man output, then it
+# will generate one additional man file for each entity documented in the real
+# man page(s). These additional files only source the real man page, but without
+# them the man command would be unable to find the correct page.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_LINKS = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the XML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_XML tag is set to YES, doxygen will generate an XML file that
+# captures the structure of the code including all documentation.
+# The default value is: NO.
+
+GENERATE_XML = NO
+
+# The XML_OUTPUT tag is used to specify where the XML pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: xml.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_OUTPUT = xml
+
+# If the XML_PROGRAMLISTING tag is set to YES, doxygen will dump the program
+# listings (including syntax highlighting and cross-referencing information) to
+# the XML output. Note that enabling this will significantly increase the size
+# of the XML output.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_PROGRAMLISTING = YES
+
+# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include
+# namespace members in file scope as well, matching the HTML output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_NS_MEMB_FILE_SCOPE = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the DOCBOOK output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_DOCBOOK tag is set to YES, doxygen will generate Docbook files
+# that can be used to generate PDF.
+# The default value is: NO.
+
+GENERATE_DOCBOOK = NO
+
+# The DOCBOOK_OUTPUT tag is used to specify where the Docbook pages will be put.
+# If a relative path is entered the value of OUTPUT_DIRECTORY will be put in
+# front of it.
+# The default directory is: docbook.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_OUTPUT = docbook
+
+# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the
+# program listings (including syntax highlighting and cross-referencing
+# information) to the DOCBOOK output. Note that enabling this will significantly
+# increase the size of the DOCBOOK output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_PROGRAMLISTING = NO
+
+#---------------------------------------------------------------------------
+# Configuration options for the AutoGen Definitions output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an
+# AutoGen Definitions (see http://autogen.sourceforge.net/) file that captures
+# the structure of the code including all documentation. Note that this feature
+# is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_AUTOGEN_DEF = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the Perl module output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_PERLMOD tag is set to YES, doxygen will generate a Perl module
+# file that captures the structure of the code including all documentation.
+#
+# Note that this feature is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_PERLMOD = NO
+
+# If the PERLMOD_LATEX tag is set to YES, doxygen will generate the necessary
+# Makefile rules, Perl scripts and LaTeX code to be able to generate PDF and DVI
+# output from the Perl module output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_LATEX = NO
+
+# If the PERLMOD_PRETTY tag is set to YES, the Perl module output will be nicely
+# formatted so it can be parsed by a human reader. This is useful if you want to
+# understand what is going on. On the other hand, if this tag is set to NO, the
+# size of the Perl module output will be much smaller and Perl will parse it
+# just the same.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_PRETTY = YES
+
+# The names of the make variables in the generated doxyrules.make file are
+# prefixed with the string contained in PERLMOD_MAKEVAR_PREFIX. This is useful
+# so different doxyrules.make files included by the same Makefile don't
+# overwrite each other's variables.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_MAKEVAR_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the preprocessor
+#---------------------------------------------------------------------------
+
+# If the ENABLE_PREPROCESSING tag is set to YES, doxygen will evaluate all
+# C-preprocessor directives found in the sources and include files.
+# The default value is: YES.
+
+ENABLE_PREPROCESSING = YES
+
+# If the MACRO_EXPANSION tag is set to YES, doxygen will expand all macro names
+# in the source code. If set to NO, only conditional compilation will be
+# performed. Macro expansion can be done in a controlled way by setting
+# EXPAND_ONLY_PREDEF to YES.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+MACRO_EXPANSION = NO
+
+# If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES then
+# the macro expansion is limited to the macros specified with the PREDEFINED and
+# EXPAND_AS_DEFINED tags.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_ONLY_PREDEF = NO
+
+# If the SEARCH_INCLUDES tag is set to YES, the include files in the
+# INCLUDE_PATH will be searched if a #include is found.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SEARCH_INCLUDES = YES
+
+# The INCLUDE_PATH tag can be used to specify one or more directories that
+# contain include files that are not input files but should be processed by the
+# preprocessor.
+# This tag requires that the tag SEARCH_INCLUDES is set to YES.
+
+INCLUDE_PATH =
+
+# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
+# patterns (like *.h and *.hpp) to filter out the header-files in the
+# directories. If left blank, the patterns specified with FILE_PATTERNS will be
+# used.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+INCLUDE_FILE_PATTERNS =
+
+# The PREDEFINED tag can be used to specify one or more macro names that are
+# defined before the preprocessor is started (similar to the -D option of e.g.
+# gcc). The argument of the tag is a list of macros of the form: name or
+# name=definition (no spaces). If the definition and the "=" are omitted, "=1"
+# is assumed. To prevent a macro definition from being undefined via #undef or
+# recursively expanded use the := operator instead of the = operator.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+PREDEFINED =
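+# For example, a hypothetical setup that hides an export decoration macro and
+# pre-defines a feature flag (both names below are placeholders, not used here)
+# could look like:
+# PREDEFINED = MY_EXPORT_MACRO= MY_FEATURE_FLAG=1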
+
+# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
+# tag can be used to specify a list of macro names that should be expanded. The
+# macro definition that is found in the sources will be used. Use the PREDEFINED
+# tag if you want to use a different macro definition that overrules the
+# definition found in the source code.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_AS_DEFINED =
+
+# If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
+# remove all references to function-like macros that are alone on a line, have
+# an all uppercase name, and do not end with a semicolon. Such function macros
+# are typically used for boiler-plate code, and will confuse the parser if not
+# removed.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SKIP_FUNCTION_MACROS = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to external references
+#---------------------------------------------------------------------------
+
+# The TAGFILES tag can be used to specify one or more tag files. For each tag
+# file the location of the external documentation should be added. The format of
+# a tag file without this location is as follows:
+# TAGFILES = file1 file2 ...
+# Adding location for the tag files is done as follows:
+# TAGFILES = file1=loc1 "file2 = loc2" ...
+# where loc1 and loc2 can be relative or absolute paths or URLs. See the
+# section "Linking to external documentation" for more information about the use
+# of tag files.
+# Note: Each tag file must have a unique name (where the name does NOT include
+# the path). If a tag file is not located in the directory in which doxygen is
+# run, you must also specify the path to the tagfile here.
+
+TAGFILES =
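+# For example, a hypothetical link to an externally hosted documentation tree
+# (the tag file name and URL below are placeholders) could look like:
+# TAGFILES = external.tag=https://example.org/otherproject/docs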
+
+# When a file name is specified after GENERATE_TAGFILE, doxygen will create a
+# tag file that is based on the input files it reads. See section "Linking to
+# external documentation" for more information about the usage of tag files.
+
+GENERATE_TAGFILE =
+
+# If the ALLEXTERNALS tag is set to YES, all external class will be listed in
+# the class index. If set to NO, only the inherited external classes will be
+# listed.
+# The default value is: NO.
+
+ALLEXTERNALS = NO
+
+# If the EXTERNAL_GROUPS tag is set to YES, all external groups will be listed
+# in the modules index. If set to NO, only the current project's groups will be
+# listed.
+# The default value is: YES.
+
+EXTERNAL_GROUPS = YES
+
+# If the EXTERNAL_PAGES tag is set to YES, all external pages will be listed in
+# the related pages index. If set to NO, only the current project's pages will
+# be listed.
+# The default value is: YES.
+
+EXTERNAL_PAGES = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the dot tool
+#---------------------------------------------------------------------------
+
+# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram
+# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to
+# NO turns the diagrams off. Note that this option also works with HAVE_DOT
+# disabled, but it is recommended to install and use dot, since it yields more
+# powerful graphs.
+# The default value is: YES.
+
+CLASS_DIAGRAMS = YES
+
+# You can include diagrams made with dia in doxygen documentation. Doxygen will
+# then run dia to produce the diagram and insert it in the documentation. The
+# DIA_PATH tag allows you to specify the directory where the dia binary resides.
+# If left empty dia is assumed to be found in the default search path.
+
+DIA_PATH =
+
+# If set to YES the inheritance and collaboration graphs will hide inheritance
+# and usage relations if the target is undocumented or is not a class.
+# The default value is: YES.
+
+HIDE_UNDOC_RELATIONS = YES
+
+# If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is
+# available from the path. This tool is part of Graphviz (see:
+# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent
+# Bell Labs. The other options in this section have no effect if this option is
+# set to NO.
+# The default value is: YES.
+
+HAVE_DOT = YES
+
+# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
+# to run in parallel. When set to 0 doxygen will base this on the number of
+# processors available in the system. You can set it explicitly to a value
+# larger than 0 to get control over the balance between CPU load and processing
+# speed.
+# Minimum value: 0, maximum value: 32, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_NUM_THREADS = 0
+
+# When you want a differently looking font in the dot files that doxygen
+# generates you can specify the font name using DOT_FONTNAME. You need to make
+# sure dot is able to find the font, which can be done by putting it in a
+# standard location or by setting the DOTFONTPATH environment variable or by
+# setting DOT_FONTPATH to the directory containing the font.
+# The default value is: Helvetica.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTNAME = Helvetica
+
+# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of
+# dot graphs.
+# Minimum value: 4, maximum value: 24, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTSIZE = 10
+
+# By default doxygen will tell dot to use the default font as specified with
+# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set
+# the path where dot can find it using this tag.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTPATH =
+
+# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for
+# each documented class showing the direct and indirect inheritance relations.
+# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CLASS_GRAPH = YES
+
+# If the COLLABORATION_GRAPH tag is set to YES then doxygen will generate a
+# graph for each documented class showing the direct and indirect implementation
+# dependencies (inheritance, containment, and class references variables) of the
+# class with other documented classes.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+COLLABORATION_GRAPH = YES
+
+# If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for
+# groups, showing the direct groups dependencies.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GROUP_GRAPHS = YES
+
+# If the UML_LOOK tag is set to YES, doxygen will generate inheritance and
+# collaboration diagrams in a style similar to the OMG's Unified Modeling
+# Language.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LOOK = NO
+
+# If the UML_LOOK tag is enabled, the fields and methods are shown inside the
+# class node. If there are many fields or methods and many nodes the graph may
+# become too big to be useful. The UML_LIMIT_NUM_FIELDS threshold limits the
+# number of items for each type to make the size more manageable. Set this to 0
+# for no limit. Note that the threshold may be exceeded by 50% before the limit
+# is enforced. So when you set the threshold to 10, up to 15 fields may appear,
+# but if the number exceeds 15, the total amount of fields shown is limited to
+# 10.
+# Minimum value: 0, maximum value: 100, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LIMIT_NUM_FIELDS = 10
+
+# If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and
+# collaboration graphs will show the relations between templates and their
+# instances.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+TEMPLATE_RELATIONS = NO
+
+# If the INCLUDE_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are set to
+# YES then doxygen will generate a graph for each documented file showing the
+# direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDE_GRAPH = YES
+
+# If the INCLUDED_BY_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are
+# set to YES then doxygen will generate a graph for each documented file showing
+# the direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDED_BY_GRAPH = YES
+
+# If the CALL_GRAPH tag is set to YES then doxygen will generate a call
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable call graphs for selected
+# functions only using the \callgraph command. Disabling a call graph can be
+# accomplished by means of the command \hidecallgraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALL_GRAPH = NO
+
+# If the CALLER_GRAPH tag is set to YES then doxygen will generate a caller
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable caller graphs for selected
+# functions only using the \callergraph command. Disabling a caller graph can be
+# accomplished by means of the command \hidecallergraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALLER_GRAPH = NO
+
+# If the GRAPHICAL_HIERARCHY tag is set to YES then doxygen will show a graphical
+# hierarchy of all classes instead of a textual one.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GRAPHICAL_HIERARCHY = YES
+
+# If the DIRECTORY_GRAPH tag is set to YES then doxygen will show the
+# dependencies a directory has on other directories in a graphical way. The
+# dependency relations are determined by the #include relations between the
+# files in the directories.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DIRECTORY_GRAPH = YES
+
+# The DOT_IMAGE_FORMAT tag can be used to set the image format of the images
+# generated by dot. For an explanation of the image formats see the section
+# output formats in the documentation of the dot tool (Graphviz (see:
+# http://www.graphviz.org/)).
+# Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order
+# to make the SVG files visible in IE 9+ (other browsers do not have this
+# requirement).
+# Possible values are: png, png:cairo, png:cairo:cairo, png:cairo:gd, png:gd,
+# png:gd:gd, jpg, jpg:cairo, jpg:cairo:gd, jpg:gd, jpg:gd:gd, gif, gif:cairo,
+# gif:cairo:gd, gif:gd, gif:gd:gd, svg, png:gd, png:gd:gd, png:cairo,
+# png:cairo:gd, png:cairo:cairo, png:cairo:gdiplus, png:gdiplus and
+# png:gdiplus:gdiplus.
+# The default value is: png.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_IMAGE_FORMAT = png
+
+# If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to
+# enable generation of interactive SVG images that allow zooming and panning.
+#
+# Note that this requires a modern browser other than Internet Explorer. Tested
+# and working are Firefox, Chrome, Safari, and Opera.
+# Note: For IE 9+ you need to set HTML_FILE_EXTENSION to xhtml in order to make
+# the SVG files visible. Older versions of IE do not have SVG support.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INTERACTIVE_SVG = NO
+
+# The DOT_PATH tag can be used to specify the path where the dot tool can be
+# found. If left blank, it is assumed the dot tool can be found in the path.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_PATH =
+
+# The DOTFILE_DIRS tag can be used to specify one or more directories that
+# contain dot files that are included in the documentation (see the \dotfile
+# command).
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOTFILE_DIRS =
+
+# The MSCFILE_DIRS tag can be used to specify one or more directories that
+# contain msc files that are included in the documentation (see the \mscfile
+# command).
+
+MSCFILE_DIRS =
+
+# The DIAFILE_DIRS tag can be used to specify one or more directories that
+# contain dia files that are included in the documentation (see the \diafile
+# command).
+
+DIAFILE_DIRS =
+
+# When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
+# path where java can find the plantuml.jar file. If left blank, it is assumed
+# PlantUML is not used or called during a preprocessing step. Doxygen will
+# generate a warning when it encounters a \startuml command in this case and
+# will not generate output for the diagram.
+
+PLANTUML_JAR_PATH =
+
+# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a
+# configuration file for plantuml.
+
+PLANTUML_CFG_FILE =
+
+# When using plantuml, the specified paths are searched for files specified by
+# the !include statement in a plantuml block.
+
+PLANTUML_INCLUDE_PATH =
+
+# The DOT_GRAPH_MAX_NODES tag can be used to set the maximum number of nodes
+# that will be shown in the graph. If the number of nodes in a graph becomes
+# larger than this value, doxygen will truncate the graph, which is visualized
+# by representing a node as a red box. Note that if the number of direct
+# children of the root node in a graph is already larger than
+# DOT_GRAPH_MAX_NODES, doxygen will not show the graph at all. Also note that
+# the size of a graph can be further restricted by MAX_DOT_GRAPH_DEPTH.
+# Minimum value: 0, maximum value: 10000, default value: 50.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_GRAPH_MAX_NODES = 50
+
+# The MAX_DOT_GRAPH_DEPTH tag can be used to set the maximum depth of the graphs
+# generated by dot. A depth value of 3 means that only nodes reachable from the
+# root by following a path via at most 3 edges will be shown. Nodes that lay
+# further from the root node will be omitted. Note that setting this option to 1
+# or 2 may greatly reduce the computation time needed for large code bases. Also
+# note that the size of a graph can be further restricted by
+# DOT_GRAPH_MAX_NODES. Using a depth of 0 means no depth restriction.
+# Minimum value: 0, maximum value: 1000, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+MAX_DOT_GRAPH_DEPTH = 0
+
+# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent
+# background. This is disabled by default, because dot on Windows does not seem
+# to support this out of the box.
+#
+# Warning: Depending on the platform used, enabling this option may lead to
+# badly anti-aliased labels on the edges of a graph (i.e. they become hard to
+# read).
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_TRANSPARENT = NO
+
+# Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output
+# files in one run (i.e. multiple -o and -T options on the command line). This
+# makes dot run faster, but since only newer versions of dot (>1.8.10) support
+# this, this feature is disabled by default.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_MULTI_TARGETS = NO
+
+# If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page
+# explaining the meaning of the various boxes and arrows in the dot generated
+# graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GENERATE_LEGEND = YES
+
+# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot
+# files that are used to generate the various graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_CLEANUP = YES
diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp
new file mode 100644
index 00000000..960ce9a0
--- /dev/null
+++ b/ggml/src/ggml-cann/acl_tensor.cpp
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "acl_tensor.h"
+
+#include <algorithm>
+#include <cstring>
+
+aclDataType ggml_cann_type_mapping(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return ACL_FLOAT;
+ case GGML_TYPE_F16:
+ return ACL_FLOAT16;
+ case GGML_TYPE_I8:
+ return ACL_INT8;
+ case GGML_TYPE_I16:
+ return ACL_INT16;
+ case GGML_TYPE_I32:
+ return ACL_INT32;
+ default:
+ return ACL_DT_UNDEFINED;
+ }
+ return ACL_DT_UNDEFINED;
+}
+
+aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
+ size_t* nb, int64_t dims, aclFormat format,
+ size_t offset) {
+    // If the tensor is broadcast, up to GGML_MAX_DIMS additional dimensions
+    // may be added.
+ int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
+
+ int64_t acl_storage_len = 0;
+ if (ne == nullptr) {
+ acl_storage_len = ggml_nbytes(tensor);
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ acl_ne[i] = tensor->ne[i];
+            // ACL strides are expressed in elements, not bytes.
+ acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);
+ }
+ } else {
+ // With bcast
+ for (int i = 0; i < dims; i++) {
+ acl_storage_len += (ne[i] - 1) * nb[i];
+ acl_ne[i] = ne[i];
+ acl_stride[i] = nb[i] / ggml_element_size(tensor);
+ }
+ }
+
+ // Reverse ne and stride.
+ int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
+ std::reverse(acl_ne, acl_ne + final_dims);
+ std::reverse(acl_stride, acl_stride + final_dims);
+
+ aclTensor* acl_tensor = aclCreateTensor(
+ acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
+ offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
+ tensor->data);
+
+ return acl_tensor;
+}
+
+bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
+ return true;
+ }
+ }
+ return false;
+}
+
+aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
+ size_t type_size, int64_t* ne, size_t* nb,
+ int64_t dims, aclFormat format,
+ size_t offset) {
+ int64_t tmp_ne[GGML_MAX_DIMS * 2];
+ int64_t tmp_stride[GGML_MAX_DIMS * 2];
+
+ memcpy(tmp_ne, ne, dims * sizeof(int64_t));
+ for (int i = 0; i < dims; i++) {
+ tmp_stride[i] = nb[i] / type_size;
+ }
+
+ std::reverse(tmp_ne, tmp_ne + dims);
+ std::reverse(tmp_stride, tmp_stride + dims);
+
+ int64_t acl_storage_len = 0;
+ for (int i = 0; i < dims; i++) {
+ acl_storage_len += (ne[i] - 1) * nb[i];
+ }
+
+ aclTensor* acl_tensor =
+ aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
+ format, &acl_storage_len, 1, data_ptr);
+
+ return acl_tensor;
+}
+
+int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
+ const ggml_tensor* src1,
+ int64_t* bcast_src0_ne,
+ int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
+ size_t* bcast_src1_nb) {
+ GGML_ASSERT(ggml_can_repeat(src1, src0));
+ int bcast_dim_cnt = 0;
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ int64_t nr = src0->ne[i] / src1->ne[i];
+ bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;
+ bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];
+ bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];
+ bcast_src1_nb[bcast_dim_cnt] = src1->nb[i];
+ bcast_dim_cnt++;
+ if (nr != 1) {
+ // Need to add an extra dim.
+ bcast_src0_ne[bcast_dim_cnt] = nr;
+ bcast_src1_ne[bcast_dim_cnt] = 1;
+ bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
+ bcast_src0_ne[bcast_dim_cnt - 1];
+ bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
+ bcast_src1_ne[bcast_dim_cnt - 1];
+ bcast_dim_cnt++;
+ }
+ }
+ return bcast_dim_cnt;
+}
+
+int64_t ggml_cann_get_mulmat_bcast_shape(
+ const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
+ const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
+ int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
+ size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
+    // input and dst should have the same shape, except for the first two dims.
+ GGML_ASSERT(input_ne[2] == dst_ne[2]);
+ GGML_ASSERT(input_ne[3] == dst_ne[3]);
+
+ int bcast_dim_cnt = 0;
+
+    // For mul_mat, a dimension needs to be inserted before the dimension in
+    // which the weight is expanded, so that the broadcast rule of matrix
+    // multiplication is satisfied.
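+    //
+    // Worked example (placeholder sizes): input_ne = {K, M, 8, 1},
+    // weight_ne = {K, N, 2, 1}, dst_ne = {N, M, 8, 1} give nr = 8 / 2 = 4 at
+    // i = 2 and produce 5 broadcast dims:
+    //   bcast_input_ne  = {K, M, 4, 2, 1}
+    //   bcast_weight_ne = {K, N, 1, 2, 1}
+    //   bcast_dst_ne    = {N, M, 4, 2, 1}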
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ int64_t nr = input_ne[i] / weight_ne[i];
+ // Do not use bcast in the first two dimensions because we only support
+ // the bcast batch dimension. Just copy them.
+ if (i < 2 || nr == 1) {
+ bcast_input_ne[bcast_dim_cnt] = input_ne[i];
+ bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
+ bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
+
+ bcast_input_nb[bcast_dim_cnt] = input_nb[i];
+ bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
+ bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
+ bcast_dim_cnt++;
+ } else {
+ // Need to add an extra dim.
+ bcast_input_ne[bcast_dim_cnt] = nr;
+ bcast_dst_ne[bcast_dim_cnt] = nr;
+ bcast_weight_ne[bcast_dim_cnt] = 1;
+ bcast_input_nb[bcast_dim_cnt] = input_nb[i];
+ bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
+ bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
+ bcast_dim_cnt++;
+
+ bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
+ bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
+ bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
+ bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
+ bcast_input_ne[bcast_dim_cnt - 1];
+ bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
+ bcast_dst_ne[bcast_dim_cnt - 1];
+ bcast_weight_nb[bcast_dim_cnt] =
+ bcast_weight_nb[bcast_dim_cnt - 1] *
+ bcast_weight_ne[bcast_dim_cnt - 1];
+ bcast_dim_cnt++;
+ }
+ }
+ return bcast_dim_cnt;
+}
diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h
new file mode 100644
index 00000000..7d0bf04e
--- /dev/null
+++ b/ggml/src/ggml-cann/acl_tensor.h
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#ifndef CANN_ACL_TENSOR_H
+#define CANN_ACL_TENSOR_H
+
+#include <aclnn/aclnn_base.h>
+#include "common.h"
+
+/**
+ * @brief Maps a ggml_type to its corresponding aclDataType.
+ *
+ * @details This function takes a ggml_type as input and returns the corresponding
+ * aclDataType. It maps GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_I8, GGML_TYPE_I16
+ * and GGML_TYPE_I32; any other type is mapped to ACL_DT_UNDEFINED.
+ *
+ * @param type The ggml_type to be mapped.
+ * @return The corresponding aclDataType. If the input type is not recognized,
+ * ACL_DT_UNDEFINED is returned.
+ */
+aclDataType ggml_cann_type_mapping(ggml_type type);
+
+/**
+ * @brief Creates an ACL tensor from a ggml_tensor with optional shape.
+ *
+ * @details This function creates an ACL tensor based on the properties of the
+ * provided ggml_tensor. It supports a custom shape by adjusting dimensions
+ * and strides accordingly. If a custom shape is applied, additional
+ * dimensions and strides are calculated based on the provided parameters.
+ *
+ * @param tensor Pointer to the ggml_tensor to be converted to an ACL tensor.
+ * @param ne Pointer to an array containing dimensions. Defaults to nullptr
+ *           if no custom shape is applied.
+ * @param nb Pointer to an array containing strides. Defaults to nullptr
+ *           if no custom shape is applied.
+ * @param dims Number of dimensions in the tensor. Defaults to 0 if no custom
+ *             shape is applied.
+ * @param format ACL tensor format. Defaults to ACL_FORMAT_ND.
+ * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
+ * @return Pointer to the created ACL tensor.
+ */
+aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
+ size_t* nb = nullptr, int64_t dims = 0,
+ aclFormat format = ACL_FORMAT_ND,
+ size_t offset = 0);
+
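+// A minimal usage sketch (illustration only; `t` is a placeholder
+// ggml_tensor*): wrap the tensor with its native shape, hand it to an aclnn
+// operator, then destroy the handle once the op has been queued.
+//
+//   aclTensor* acl_t = ggml_cann_create_tensor(t);
+//   // ... pass acl_t to an aclnn operator ...
+//   ACL_CHECK(aclDestroyTensor(acl_t));
+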
+/**
+ * @brief Creates an ACL tensor from provided parameters.
+ *
+ * @details This function creates an ACL tensor using the provided data pointer,
+ * data type, dimensions, strides, format, offset, and additional parameters.
+ * It calculates necessary dimensions and strides based on the provided ne and nb
+ * arrays, adjusting them for the ACL tensor creation. The ACL storage length
+ * is also calculated based on the provided dimensions and strides.
+ *
+ * @param data_ptr Pointer to the data buffer for the ACL tensor.
+ * @param dtype ACL data type of the tensor.
+ * @param type_size Size of each element in the tensor data buffer.
+ * @param ne Pointer to an array containing tensor dimensions.
+ * @param nb Pointer to an array containing tensor strides.
+ * @param dims Number of dimensions of the tensor.
+ * @param format ACL tensor format. Defaults to ACL_FORMAT_ND.
+ * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
+ * @return Pointer to the created ACL tensor.
+ */
+aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
+ size_t type_size, int64_t* ne, size_t* nb,
+ int64_t dims, aclFormat format = ACL_FORMAT_ND,
+ size_t offset = 0);
+
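+// A minimal usage sketch for the raw-buffer overload (illustration only; the
+// buffer is a placeholder device pointer, e.g. obtained from a
+// ggml_cann_pool_alloc):
+//
+//   int64_t ne[] = {64, 32};                              // logical shape
+//   size_t  nb[] = {sizeof(float), 64 * sizeof(float)};   // strides in bytes
+//   aclTensor* tmp = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float),
+//                                            ne, nb, 2);
+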
+/**
+ * @brief Checks if tensors require broadcasting based on their shapes.
+ *
+ * @details This function determines if two ggml_tensors need to be broadcasted for
+ * element-wise operations. Broadcasting is needed when a dimension of t1 differs
+ * from the corresponding dimension of t0 and is not equal to 1.
+ *
+ * @param t0 Pointer to the first ggml_tensor.
+ * @param t1 Pointer to the second ggml_tensor.
+ * @return True if broadcasting is needed, False otherwise.
+ *
+ * @remarks This function iterates over the dimensions of t0 and t1. It checks if each
+ * dimension in t1 differs from t0's corresponding dimension and is not equal
+ * to 1. If such a dimension is found, broadcasting is required to align t1
+ * with t0 for element-wise operations.
+ */
+bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
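+
+// Illustrative shapes (placeholders): with t0->ne = {4, 6, 2, 1} and
+// t1->ne = {4, 3, 2, 1} this returns true, because dim1 of t1 differs from
+// dim1 of t0 and is not 1; with t1->ne = {4, 1, 2, 1} it returns false.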
+
+/**
+ * @brief Computes broadcast shapes and strides for two ggml_tensors.
+ *
+ * @details This function calculates the broadcast shapes and strides for two ggml_tensors,
+ * following the broadcasting rules similar to numpy. It adjusts dimensions and
+ * strides to ensure compatibility for element-wise operations where one tensor
+ * can be broadcasted to match the shape of another tensor.
+ *
+ * @param src0 Pointer to the first ggml_tensor.
+ * @param src1 Pointer to the second ggml_tensor.
+ * @param bcast_ne_src0 Output array to store broadcasted dimensions for src0.
+ * @param bcast_ne_src1 Output array to store broadcasted dimensions for src1.
+ * @param bcast_nb_src0 Output array to store broadcasted strides for src0.
+ * @param bcast_nb_src1 Output array to store broadcasted strides for src1.
+ * @return Number of dimensions in the broadcasted shape.
+ *
+ * @pre ggml_can_repeat(src1, src0) must return true, indicating src1 can be broadcasted
+ * to match src0.
+ *
+ * @remarks This function iterates over the dimensions of src0 and src1, calculating the
+ * necessary broadcast dimensions and strides. If a dimension requires broadcasting
+ * (i.e., its size in src1 is smaller than in src0), an additional dimension is
+ * added with size calculated to match src0's dimension. This adjustment ensures
+ * that src1 can be element-wise broadcasted to src0's shape.
+ *
+ * How it works:
+ * \code
+ * if dim0 has padding.
+ * a -> (2, 2) padding = 2
+ * a: [[1, 2, *, *]
+ * [2, 3, *, *]]
+ * nb = (8, 4, 2)
+ *
+ * if a should bcast with b -> (2, 4)
+ * b' -> (2, 2, 2)
+ * b : [[1, 2, 3, 4, *, *]
+ * [5, 6, 7, 8, *, *]]
+ * nb = (12, 6, 1)
+ *
+ * after bcast:
+ * a' -> (2, 1, 2)
+ * a': [[[1, 2], *, *]
+ * [[2, 3], *, *]]
+ * nb = (8, 4, 2, 1)
+ *
+ * b' : [[[1, 2], [3, 4], *, *]
+ * [[5, 6], [7, 8], *, *]]
+ * nb = (12, 6, 2, 1)
+ * \endcode
+ *
+ * dim1 in a is the inserted dim; an nb entry is added for dim1,
+ * and all the other nb values shift to the next position in order.
+ */
+int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
+ int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
+ size_t* bcast_nb_src0, size_t* bcast_nb_src1);
+
+// Bcast macro to avoid duplicate code.
+#define BCAST_SHAPE(src0, src1) \
+ int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
+ int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
+ size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
+ size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
+ int64_t bcast_dims = ggml_cann_get_bcast_shape( \
+ src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
+ bcast_##src1##_nb);
+
+#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
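+
+// A minimal usage sketch, mirroring how the element-wise ops in aclnn_ops.cpp
+// use these macros (src0, src1 and dst are placeholder ggml_tensor*):
+//
+//   BCAST_SHAPE(src0, src1)
+//   aclTensor* acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+//   aclTensor* acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+//   aclTensor* acl_dst  = ggml_cann_create_tensor(dst,  BCAST_PARAM(src0));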
+
+/**
+ * @brief Calculates broadcast shapes for matrix multiplication.
+ *
+ * @details This function computes the broadcast shapes required for matrix multiplication
+ * based on the input, weight, and destination tensor shapes. It ensures that the
+ * dimensions of weight tensors are expanded appropriately to satisfy matrix
+ * multiplication broadcast rules.
+ *
+ * @param input_ne Array containing the dimensions of the input tensor.
+ * @param weight_ne Array containing the dimensions of the weight tensor.
+ * @param dst_ne Array containing the dimensions of the destination tensor.
+ * @param input_nb Array containing the strides of the input tensor.
+ * @param weight_nb Array containing the strides of the weight tensor.
+ * @param dst_nb Array containing the strides of the destination tensor.
+ * @param bcast_input_ne Output array for broadcasted input tensor dimensions.
+ * @param bcast_weight_ne Output array for broadcasted weight tensor dimensions.
+ * @param bcast_dst_ne Output array for broadcasted destination tensor dimensions.
+ * @param bcast_input_nb Output array for broadcasted input tensor strides.
+ * @param bcast_weight_nb Output array for broadcasted weight tensor strides.
+ * @param bcast_dst_nb Output array for broadcasted destination tensor strides.
+ * @return The number of dimensions in the broadcasted tensors.
+ *
+ * @remarks This function iterates over the tensor dimensions and calculates the broadcast
+ * shapes needed for matrix multiplication. It ensures that dimensions where
+ * weight tensor requires expansion are appropriately handled to conform with
+ * broadcasting rules.
+ * @note Compared with ggml_cann_get_bcast_shape, mul_mat broadcasting needs to
+ * insert the new dimension before the broadcast dimension.
+ * @sa ggml_cann_get_bcast_shape
+ */
+int64_t ggml_cann_get_mulmat_bcast_shape(
+ const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
+ const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
+ int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
+ size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
+
+// Bcast macro to avoid duplicate code.
+#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
+ int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
+ int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
+ int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
+ size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
+ size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
+ size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
+ int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
+ input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
+ bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \
+ bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
+
+#define BCAST_MUL_MAT_PARAM(tensor) \
+ bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
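+
+// A hypothetical usage sketch (input, weight and dst are placeholder
+// ggml_tensor*; the actual mul_mat implementation may differ):
+//
+//   BCAST_MUL_MAT_SHAPE(input, weight, dst)
+//   aclTensor* acl_input  = ggml_cann_create_tensor(input,  BCAST_MUL_MAT_PARAM(input));
+//   aclTensor* acl_weight = ggml_cann_create_tensor(weight, BCAST_MUL_MAT_PARAM(weight));
+//   aclTensor* acl_dst    = ggml_cann_create_tensor(dst,    BCAST_MUL_MAT_PARAM(dst));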
+
+#endif // CANN_ACL_TENSOR_H
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
new file mode 100644
index 00000000..a02efc82
--- /dev/null
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -0,0 +1,2944 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "aclnn_ops.h"
+
+#include <aclnnop/aclnn_avgpool2d.h>
+#include <aclnnop/aclnn_cast.h>
+#include <aclnnop/aclnn_constant_pad_nd.h>
+#include <aclnnop/aclnn_copy.h>
+#include <aclnnop/aclnn_cos.h>
+#include <aclnnop/aclnn_exp.h>
+#include <aclnnop/aclnn_fill_scalar.h>
+#include <aclnnop/aclnn_group_norm.h>
+#include <aclnnop/aclnn_index_fill_tensor.h>
+#include <aclnnop/aclnn_layer_norm.h>
+#include <aclnnop/aclnn_matmul.h>
+#include <aclnnop/aclnn_max_pool.h>
+#include <aclnnop/aclnn_permute.h>
+#include <aclnnop/aclnn_pow_tensor_tensor.h>
+#include <aclnnop/aclnn_reduce_sum.h>
+#include <aclnnop/aclnn_repeat.h>
+#include <aclnnop/aclnn_repeat_interleave.h>
+#include <aclnnop/aclnn_roll.h>
+#include <aclnnop/aclnn_sin.h>
+#include <aclnnop/aclnn_softmax.h>
+#include <aclnnop/aclnn_tril.h>
+#include <aclnnop/aclnn_triu.h>
+#include <aclnnop/aclnn_upsample_nearest_2d.h>
+#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
+#include <float.h>
+
+#include <cmath>
+#include <cstring>
+#include <exception>
+#include <vector>
+
+#include "kernels/ascendc_kernels.h"
+
+#define GGML_COMMON_DECL_C
+
+#include "../ggml-common.h"
+
+/**
+ * @brief Repeats elements of a tensor along each dimension according to the
+ * specified repeat array.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be repeated.
+ * @param acl_dst The destination tensor after repeating.
+ * @param repeat_array The array specifying the number of repetitions along each
+ * dimension.
+ */
+static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst, int64_t* repeat_array) {
+ // repeat tensor along each dim with repeat_array
+ aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+        // Memory from the allocator is "freed" immediately and may be handed
+        // out to other pointers, but it will not be accessed before this
+        // async task ends, because all tasks in the same stream execute in
+        // order.
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(
+ aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream()));
+ ACL_CHECK(aclDestroyIntArray(repeats));
+}
+
+void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ GGML_ASSERT(ggml_can_repeat(src, dst));
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
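+    // The repeat factors below are listed from dst->ne[3] down to dst->ne[0],
+    // matching the reversed dimension order used when creating ACL tensors
+    // (see acl_tensor.cpp). For example (placeholder shapes), repeating
+    // src {2, 3, 1, 1} into dst {2, 3, 4, 1} gives repeatsArray = {1, 4, 1, 1}.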
+ int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
+ dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
+
+ aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Adds two tensors element-wise and stores the result in a destination
+ * tensor.
+ *
+ * This function performs the operation:
+ * \f[
+ * dst = acl\_src0 + alpha \times acl\_src1
+ * \f]
+ * where alpha is a scalar value and defaults to 1.0f.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src0 The first source tensor.
+ * @param acl_src1 The second source tensor.
+ * @param acl_dst The destination tensor where the result will be stored.
+ */
+static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
+ aclTensor* acl_src1, aclTensor* acl_dst) {
+ aclScalar* alpha = nullptr;
+ float alphaValue = 1.0f;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(alpha));
+}
+
+void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1];
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+ aclTensor* acl_src0;
+ aclTensor* acl_src1;
+ aclTensor* acl_dst;
+
+ // Need bcast
+ if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
+ BCAST_SHAPE(src0, src1)
+ acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+ acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+ acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+ } else {
+ acl_src0 = ggml_cann_create_tensor(src0);
+ acl_src1 = ggml_cann_create_tensor(src1);
+ acl_dst = ggml_cann_create_tensor(dst);
+ }
+
+ aclnn_add(ctx, acl_src0, acl_src1, acl_dst);
+
+ ACL_CHECK(aclDestroyTensor(acl_src0));
+ ACL_CHECK(aclDestroyTensor(acl_src1));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+ aclScalar* acl_negative_slope =
+ aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnLeakyReluGetWorkspaceSize(
+ acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(acl_negative_slope));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Concatenates a list of tensors along a specified dimension and stores
+ * the result in a destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param tensorList The list of tensors to be concatenated.
+ * @param acl_dst The destination tensor where the concatenated result will be
+ * stored.
+ * @param concat_dim The dimension along which the tensors will be concatenated.
+ */
+static void aclnn_concat(ggml_backend_cann_context& ctx,
+ aclTensorList* tensorList, aclTensor* acl_dst,
+ int64_t concat_dim) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1];
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ int64_t concat_dim = 1;
+ aclTensor* tensors[] = {acl_src0, acl_src1};
+ aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
+ aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+
+ ACL_CHECK(aclDestroyTensorList(tensorList));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Creates a tensor with values starting from `start`, incremented by
+ * `step`, and ending before `stop`.
+ *
+ * This function performs the operation:
+ * \f[
+ * \text{out}_{i+1} = \text{out}_i + \text{step}
+ * \f]
+ * the range is [start, stop).
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_dst The destination tensor where the values will be stored.
+ * @param start The starting value of the range.
+ * @param stop The ending value of the range (exclusive).
+ * @param step The step size between consecutive values.
+ * @param n_elements The number of elements in the destination tensor.
+ */
+static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
+ float start, float stop, float step,
+ int64_t n_elements) {
+ int64_t steps = (int64_t)std::ceil((stop - start) / step);
+ GGML_ASSERT(n_elements == steps);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT);
+ aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
+ aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);
+
+ ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(acl_start));
+ ACL_CHECK(aclDestroyScalar(acl_end));
+ ACL_CHECK(aclDestroyScalar(acl_step));
+}
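+
+// Worked example (placeholder values): start = 0, stop = 5, step = 1 fills the
+// tensor with {0, 1, 2, 3, 4}; steps = ceil((5 - 0) / 1) = 5, which must equal
+// n_elements.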
+
+void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ int64_t n_elements = ggml_nelements(dst);
+ float start;
+ float stop;
+ float step;
+ memcpy(&start, (float*)dst->op_params + 0, sizeof(float));
+ memcpy(&stop, (float*)dst->op_params + 1, sizeof(float));
+ memcpy(&step, (float*)dst->op_params + 2, sizeof(float));
+
+ aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ dst->src[1] = dst->src[0];
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+}
+
+void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float min;
+ float max;
+ memcpy(&min, dst->op_params, sizeof(float));
+ memcpy(&max, (float*)dst->op_params + 1, sizeof(float));
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
+ aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(acl_min));
+ ACL_CHECK(aclDestroyScalar(acl_max));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ // scale factor
+ float v;
+ memcpy(&v, dst->op_params, sizeof(float));
+
+ aclScalar* scale = aclCreateScalar(&v, aclDataType::ACL_FLOAT);
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize,
+ &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(scale));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ enum ggml_sort_order order = (enum ggml_sort_order)dst->op_params[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+ ggml_cann_pool_alloc temp_buffer_allocator(
+ ctx.pool(), ggml_nelements(dst) * sizeof(int64_t));
+ void* buffer = temp_buffer_allocator.get();
+ aclTensor* tmp_tensor =
+ ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
+ dst->ne, dst->nb, GGML_MAX_DIMS);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnArgsortGetWorkspaceSize(
+ acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ workspaceSize = 0;
+ ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor,
+ ggml_cann_type_mapping(dst->type),
+ acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(tmp_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ std::vector<int64_t> normData = {dst->ne[0]};
+ aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
+ ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr,
+ eps, acl_dst, nullptr, nullptr,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(norm));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ const float eps = 1e-6f; // TODO: make this a parameter
+ int n_groups = dst->op_params[0];
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ int64_t N = src->ne[3];
+ int64_t C = src->ne[2];
+ int64_t HxW = src->ne[1] * src->ne[0];
+
+ size_t type_size = ggml_type_size(src->type);
+ int64_t ne[] = {n_groups, N};
+ size_t nb[] = {type_size, type_size * n_groups};
+ size_t n_bytes = N * n_groups;
+
+ ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes * 2);
+ void* buffer = temp_buffer_allocator.get();
+ aclTensor* acl_mean_out = ggml_cann_create_tensor(
+ buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
+ aclTensor* acl_rstd_out = ggml_cann_create_tensor(
+ (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
+
+ ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
+ acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
+ acl_mean_out, acl_rstd_out, &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyTensor(acl_mean_out));
+ ACL_CHECK(aclDestroyTensor(acl_rstd_out));
+}
+
+void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1];
+
+ size_t nb1 = ((int32_t*)dst->op_params)[0];
+ size_t nb2 = ((int32_t*)dst->op_params)[1];
+ size_t nb3 = ((int32_t*)dst->op_params)[2];
+ size_t offset = ((int32_t*)dst->op_params)[3];
+ bool inplace = (bool)((int32_t*)dst->op_params)[4];
+
+ size_t param_nb[] = {ggml_element_size(src0), nb1, nb2, nb3};
+
+ aclTensor* acl_dst = ggml_cann_create_tensor(
+ dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+
+ aclScalar* alpha = nullptr;
+ float alphaValue = 1.0f;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ if (!inplace) {
+ size_t cpy_size = ggml_nbytes(dst);
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
+ ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+ aclTensor* acl_src0 = ggml_cann_create_tensor(
+ src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
+ ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(
+ aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+ ACL_CHECK(aclDestroyTensor(acl_src0));
+ } else {
+ ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
+ ctx.stream()));
+ }
+
+ ACL_CHECK(aclDestroyTensor(acl_src1));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+
+ GGML_ASSERT(dst->ne[0] == 1);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ int64_t reduce_dims_host[] = {3};
+ aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnReduceSumGetWorkspaceSize(
+ acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ aclTensor* acl_src =
+ ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+ aclTensor* acl_dst =
+ ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+ std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
+ auto output_size_array = aclCreateIntArray(output_size.data(), 2);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize(
+ acl_src, output_size_array, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor,
+ ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(output_size_array));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Pads a tensor with a specified value along each dimension.
+ *
+ * This function performs padding of the source tensor `acl_src` and stores the
+ * result in the destination tensor `acl_dst`. The padding values for each
+ * dimension are specified in the `paddings` array.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be padded.
+ * @param acl_dst The destination tensor where the padded result will be stored.
+ * @param paddings An array specifying the padding values for each dimension.
+ * The size of the array should be twice the number of dimensions of the tensor.
+ * @param value The value to be used for padding. The default value is 0.0.
+ */
+static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst, int64_t* paddings,
+ float value = 0.0f) {
+ aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
+ aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize(
+ acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor,
+ ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(acl_pad));
+ ACL_CHECK(aclDestroyScalar(acl_value));
+}
+
+void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    // padding: each value in the array is the amount of padding to add.
+    // The position of an element selects the side of a dimension to pad,
+    // in the order: [dim0.front, dim0.behind, dim1.front, dim1.behind,
+    //               dim2.front, dim2.behind, dim3.front, dim3.behind]
+ int64_t paddings[] = {
+ 0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
+ 0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
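+    // e.g. (placeholder shapes) padding src {2, 2, 1, 1} up to dst {4, 3, 1, 1}
+    // gives paddings = {0, 2, 0, 1, 0, 0, 0, 0}: 2 elements appended behind
+    // dim0 and 1 element appended behind dim1.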
+ aclnn_pad(ctx, acl_src, acl_dst, paddings);
+
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+}
+
+/**
+ * @brief Performs 2D average pooling on the input tensor and stores the result
+ * in the destination tensor.
+ *
+ * This function performs average pooling on the source tensor and stores the
+ * result in the destination tensor. The pooling parameters (kernel size,
+ * strides, padding) are specified in the `op_params` of the destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result will be stored. The source
+ * tensor is referenced by `dst->src[0]`.
+ */
+static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_src =
+ ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+ aclTensor* acl_dst =
+ ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+ const int32_t* opts = (const int32_t*)dst->op_params;
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ std::vector<int64_t> kernel_dims = {k1, k0};
+ std::vector<int64_t> stride_dims = {s1, s0};
+ std::vector<int64_t> padding_avg_dims = {p1, p0}; // (padH, padW)
+
+ auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+ auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+ auto* paddings_avg = aclCreateIntArray(padding_avg_dims.data(), 2);
+
+ bool ceil_mode = false;
+ bool count_include_pad = true;
+ int64_t divisor_override = 0;
+ int8_t cube_math_type = 0;
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize(
+ acl_src, kernel_size, strides, paddings_avg, ceil_mode,
+ count_include_pad, divisor_override, cube_math_type, acl_dst,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(
+ aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyIntArray(kernel_size));
+ ACL_CHECK(aclDestroyIntArray(strides));
+ ACL_CHECK(aclDestroyIntArray(paddings_avg));
+}
+
+/**
+ * @brief Performs 2D max pooling on the input tensor and stores the result in
+ * the destination tensor.
+ *
+ * This function performs max pooling on the source tensor and stores the result
+ * in the destination tensor. The pooling parameters (kernel size, strides,
+ * padding) are specified in the `op_params` of the destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result will be stored. The source
+ * tensor is referenced by `dst->src[0]`.
+ */
+static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_src =
+ ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+ aclTensor* acl_dst =
+ ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+ const int32_t* opts = (const int32_t*)dst->op_params;
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ int64_t temp_ne[] = {src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2],
+ src->ne[3]};
+ size_t temp_nb[GGML_MAX_DIMS];
+
+ temp_nb[0] = ggml_element_size(src);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ temp_nb[i] = temp_nb[i - 1] * temp_ne[i - 1];
+ }
+
+ ggml_cann_pool_alloc temp_buffer_allocator(
+ ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]);
+ void* buffer = temp_buffer_allocator.get();
+ aclTensor* tmp_tensor = ggml_cann_create_tensor(
+ buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_NCHW);
+
+ // pad: see padding in ggml_cann_pad()
+ int64_t paddings[] = {p0, p0, p1, p1, 0, 0, 0, 0};
+ float value = -FLT_MAX;
+ aclnn_pad(ctx, acl_src, tmp_tensor, paddings, value);
+
+ // max_pool
+ std::vector<int64_t> kernel_dims = {k1, k0};
+ std::vector<int64_t> stride_dims = {s1, s0};
+ // padding_max_dims: [dim0_start, dim0_end, dim1_start, dim1_end]
+ std::vector<int64_t> padding_max_dims = {0, 0, 0, 0};
+ std::vector<int64_t> dilation_size = {1, 1};
+ auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+ auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+ auto* paddings_max = aclCreateIntArray(padding_max_dims.data(), 4);
+ auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
+
+ bool ceil_mode = false;
+ int64_t auto_pads = 0;
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnMaxPoolGetWorkspaceSize(
+ tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations,
+ ceil_mode, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyTensor(tmp_tensor));
+ ACL_CHECK(aclDestroyIntArray(kernel_size));
+ ACL_CHECK(aclDestroyIntArray(strides));
+ ACL_CHECK(aclDestroyIntArray(paddings_max));
+ ACL_CHECK(aclDestroyIntArray(dilations));
+}
+
+void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ const int32_t* opts = (const int32_t*)dst->op_params;
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+ switch (op) {
+ case GGML_OP_POOL_AVG:
+ ggml_cann_avg_pool2d(ctx, dst);
+ break;
+ case GGML_OP_POOL_MAX:
+ ggml_cann_max_pool2d(ctx, dst);
+ break;
+ case GGML_OP_POOL_COUNT:
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
+/**
+ * @brief Copies data from the source tensor to the destination tensor.
+ *
+ * This function copies data from the source tensor `acl_src` to the destination
+ * tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor from which data will be copied.
+ * @param acl_dst The destination tensor where the data will be copied to.
+ */
+static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize,
+ &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ ggml_cann_pool_alloc src_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+ ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+ src->extra = src_extra_allocator.get();
+ dst->extra = dst_extra_allocator.get();
+ ACL_CHECK(aclrtMemcpyAsync(src->extra, sizeof(ggml_tensor), src,
+ sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+ ctx.stream()));
+ ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
+ sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+ ctx.stream()));
+
+ if ((dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) &&
+ ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ // TODO: simplify
+ if (src->type == GGML_TYPE_F16) {
+ if (dst->type == GGML_TYPE_Q8_0) {
+ aclrtlaunch_ascendc_quantize_f16_q8_0(
+ 24, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne);
+ return;
+ }
+ if (dst->type == GGML_TYPE_F16) {
+ if (ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ if (ggml_is_contiguous(dst)) {
+ const size_t src_type_size = ggml_type_size(src->type);
+ if (src->nb[0] == src_type_size) {
+                    // src0 is contiguous on the first dimension, copy by rows
+ int64_t rows_num = ggml_nrows(src);
+
+ aclrtlaunch_ascendc_dup_by_rows_fp16(
+ rows_num, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne,
+ ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ return;
+ }
+ GGML_ASSERT(false);
+ }
+ GGML_ASSERT(false);
+ }
+ if (dst->type == GGML_TYPE_F32) {
+ if (ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ if (ggml_is_contiguous(dst)) {
+ const size_t src_type_size = ggml_type_size(src->type);
+ if (src->nb[0] == src_type_size) {
+                    // src0 is contiguous on the first dimension, copy by rows
+ int64_t rows_num = ggml_nrows(src);
+ aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32(
+ rows_num, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne,
+ ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ return;
+ }
+ GGML_ASSERT(false);
+ }
+ GGML_ASSERT(false);
+ }
+ // TODO
+ GGML_ASSERT(false);
+ } else if (src->type == GGML_TYPE_F32) {
+ // TODO: if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size
+ // && nb0 == type_size)
+ if (dst->type == GGML_TYPE_Q8_0) {
+ aclrtlaunch_ascendc_quantize_f32_q8_0(
+ 24, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne);
+ return;
+ }
+ if (dst->type == GGML_TYPE_F32) {
+ if (ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ if (ggml_is_contiguous(dst)) {
+ const size_t src_type_size = ggml_type_size(src->type);
+ if (src->nb[0] == src_type_size) {
+                    // src0 is contiguous on the first dimension, copy by rows
+ int64_t rows_num = ggml_nrows(src);
+ aclrtlaunch_ascendc_dup_by_rows_fp32(
+ rows_num, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne,
+ ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ return;
+ }
+ GGML_ASSERT(false);
+ } else {
+ // TODO: dst not contiguous
+ GGML_ASSERT(false);
+ }
+ }
+ if (dst->type == GGML_TYPE_F16) {
+ if (ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ if (ggml_is_contiguous(dst)) {
+ const size_t src_type_size = ggml_type_size(src->type);
+ if (src->nb[0] == src_type_size) {
+                    // src0 is contiguous on the first dimension, copy by rows
+ int64_t rows_num = ggml_nrows(src);
+ aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16(
+ rows_num, ctx.stream(), src->data, dst->data,
+ ((ggml_tensor*)src->extra)->ne,
+ ((ggml_tensor*)src->extra)->nb,
+ ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ return;
+ }
+ GGML_ASSERT(false);
+ }
+ }
+ // TODO
+ GGML_ASSERT(false);
+ } else {
+ if (ggml_are_same_shape(src, dst)) {
+ cann_copy(ctx, acl_src, acl_dst);
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ return;
+ }
+ GGML_ASSERT(false);
+ }
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x,
+ const aclTensor* gamma, double epsilon,
+ const aclTensor* yOut,
+ const aclTensor* rstdOout,
+ uint64_t* workspaceSize,
+ aclOpExecutor** executor);
+aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize,
+ aclOpExecutor* executor, aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
+/**
+ * @brief Creates an ACL tensor initialized with zeros using a provided buffer.
+ *
+ * This function initializes a tensor with zeros using the specified buffer and
+ * tensor parameters.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param buffer The buffer to be used for the tensor data.
+ * @param n_bytes The size of the buffer in bytes.
+ * @param ne An array specifying the extents (sizes) of each dimension of the
+ * tensor.
+ * @param dims The number of dimensions of the tensor.
+ * @param type The data type of the tensor.
+ * @param type_size The size of each element in the tensor data type.
+ * @return An ACL tensor initialized with zeros.
+ */
+static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
+ size_t n_bytes, int64_t* ne, int64_t dims,
+ aclDataType type, size_t type_size) {
+ size_t nb[GGML_MAX_DIMS];
+ nb[0] = type_size;
+ for (int i = 1; i < dims; i++) {
+ nb[i] = nb[i - 1] * ne[i - 1];
+ }
+
+ ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
+ aclTensor* zero =
+ ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
+ return zero;
+}
+
+/**
+ * @brief Creates an ACL tensor initialized with ones using a provided buffer.
+ *
+ * This function initializes a tensor with ones using the specified buffer and
+ * tensor parameters.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param buffer The buffer to be used for the tensor data.
+ * @param n_bytes The size of the buffer in bytes.
+ * @param ne An array specifying the extents (sizes) of each dimension of the
+ * tensor.
+ * @param dims The number of dimensions of the tensor.
+ * @param type The data type of the tensor.
+ * @param type_size The size of each element in the tensor data type.
+ * @param value The value to be used for initializing the tensor (default
+ * is 1.0).
+ * @return An ACL tensor initialized with the specified value (ones by default).
+ */
+static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, void* buffer,
+ size_t n_bytes, int64_t* ne, int64_t dims,
+ aclDataType type, size_t type_size,
+ float value = 1.0f) {
+ aclTensor* acl_tensor =
+ aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size);
+ float alpha_host = 1.0f;
+ aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
+ aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(
+ aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ return acl_tensor;
+}
+
+void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(eps > 0.0f);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
+ ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+
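+    // gamma is filled with ones, so aclnnRmsNorm computes x / rms(x) with unit
+    // scaling here; the RMS norm weight is presumably applied by a separate op.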
+ aclTensor* acl_gamma = aclnn_ones(
+ ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
+ ggml_cann_type_mapping(src->type), ggml_element_size(src));
+
+ size_t zero_tensor_n_bytes =
+ src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
+ ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
+ aclTensor* acl_rstd =
+ aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
+ src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
+ ggml_element_size(src));
+
+ ACL_CHECK(aclnnRmsNormGetWorkspaceSize(
+ acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyTensor(acl_gamma));
+ ACL_CHECK(aclDestroyTensor(acl_rstd));
+}
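+
+// Reference sketch, for illustration only (not used by the backend): the
+// row-wise computation the aclnnRmsNorm call above is expected to perform,
+// assuming the standard ggml RMS-norm definition. The helper name is
+// hypothetical.
+static void rms_norm_row_reference(const float* x, const float* gamma,
+                                   float* y, int64_t n, float eps) {
+    float sum_sq = 0.0f;
+    for (int64_t i = 0; i < n; i++) {
+        sum_sq += x[i] * x[i];
+    }
+    const float inv_rms = 1.0f / sqrtf(sum_sq / n + eps);
+    for (int64_t i = 0; i < n; i++) {
+        // acl_gamma above is an all-ones tensor, so this reduces to
+        // x[i] * inv_rms for ggml_cann_rms_norm.
+        y[i] = x[i] * inv_rms * gamma[i];
+    }
+}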
+
+// TODO: performance is low.
+void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+ float value) {
+ ggml_tensor* src = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ const int n_past = ((int32_t*)dst->op_params)[0];
+
+ size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
+ src->ne[3] * ggml_element_size(src);
+ ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+
+ aclTensor* mask_tensor =
+ aclnn_ones(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne,
+ GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
+ ggml_element_size(src), value);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ aclScalar* alpha = nullptr;
+ float alphaValue = 1.0f;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+ ACL_CHECK(
+ aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(alpha));
+ ACL_CHECK(aclDestroyTensor(mask_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
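+
+// Reference sketch, for illustration only: the intended semantics of the
+// triu/tril/add sequence above, written for a single rows x cols matrix of
+// the batch (the helper name is hypothetical).
+static void diag_mask_reference(const float* src, float* dst, int64_t rows,
+                                int64_t cols, int n_past, float value) {
+    for (int64_t r = 0; r < rows; r++) {
+        for (int64_t c = 0; c < cols; c++) {
+            // Keep entries up to n_past columns past the diagonal and replace
+            // everything further right with `value` (e.g. -INFINITY).
+            dst[r * cols + c] =
+                (c <= r + n_past) ? src[r * cols + c] : value;
+        }
+    }
+}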
+
+/**
+ * @brief Casts the data type of a source tensor to a destination tensor.
+ *
+ * This function casts the data type of the source tensor `acl_src` to the
+ * specified data type `cast_data_type` and stores the result in the destination
+ * tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose data type will be cast.
+ * @param acl_dst The destination tensor where the cast result will be stored.
+ * @param cast_data_type The target data type to which the source tensor will
+ * be cast.
+ */
+static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst, aclDataType cast_data_type) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Permutes the dimensions of a tensor according to a specified order.
+ *
+ * This function permutes the dimensions of the source tensor `acl_src`
+ * according to the order specified in the `new_dim` array and stores the result
+ * in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose dimensions will be permuted.
+ * @param acl_dst The destination tensor where the permuted result will be
+ * stored.
+ * @param new_dim An array specifying the new order of dimensions for the
+ * tensor.
+ * @param dims The number of dimensions in the tensor.
+ */
+static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
+ aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(acl_dims));
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self,
+ const aclIntArray* kernelSize,
+ const aclIntArray* dilation,
+ const aclIntArray* padding,
+ const aclIntArray* stride,
+ aclTensor* out, uint64_t* workspaceSize,
+ aclOpExecutor** executor);
+aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
+ aclOpExecutor* executor, aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0]; // kernel
+ ggml_tensor* src1 = dst->src[1]; // input
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int64_t N = is_2D ? ne13 : ne12;
+ const int64_t IC = is_2D ? ne12 : ne11;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne2 : 1;
+ const int64_t OW = ne1;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH]
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+ int64_t tmp_im2col_ne[] = {OW * OH, IC * KH * KW, N};
+ size_t tmp_im2col_nb[GGML_MAX_DIMS - 1];
+
+ tmp_im2col_nb[0] = ggml_type_size(src1->type);
+ for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+ tmp_im2col_nb[i] = tmp_im2col_nb[i - 1] * tmp_im2col_ne[i - 1];
+ }
+
+    // Calculate im2col.
+    // If dst is f16 the temporary buffer is still f32, so allocate
+    // src type size * dst element count bytes.
+ ggml_cann_pool_alloc im2col_allocator(
+ ctx.pool(), ggml_nelements(dst) * ggml_element_size(src1));
+ void* tmp_im2col_buffer = im2col_allocator.get();
+ aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor(
+ tmp_im2col_buffer, ggml_cann_type_mapping(src1->type),
+ ggml_type_size(src1->type), tmp_im2col_ne, tmp_im2col_nb,
+ GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+
+ std::vector<int64_t> kernel_dims = {KH, KW};
+ std::vector<int64_t> dilation_size = {d1, d0};
+ std::vector<int64_t> padding_dims = {p1, p0};
+ std::vector<int64_t> stride_dims = {s1, s0};
+ auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+ auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
+ auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
+ auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations,
+ paddings, strides, tmp_im2col_tensor,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ // Cast if dst is f16.
+ aclTensor* tmp_cast_tensor = nullptr;
+ ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool());
+ if (src1->type != dst->type) {
+ tmp_cast_allocator.alloc(ggml_nbytes(dst));
+ void* tmp_cast_buffer = tmp_cast_allocator.get();
+ size_t temp_cast_nb[GGML_MAX_DIMS - 1];
+ temp_cast_nb[0] = ggml_type_size(dst->type);
+ for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+ temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1];
+ }
+
+ tmp_cast_tensor = ggml_cann_create_tensor(
+ tmp_cast_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb,
+ GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+ aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor,
+ ggml_cann_type_mapping(dst->type));
+ }
+
+ // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
+ int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
+ size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
+ aclTensor* acl_dst =
+ ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);
+
+ int64_t permute_dim[] = {0, 2, 1};
+ if (src1->type != dst->type) {
+ aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3);
+ } else {
+ aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
+ }
+
+ // release
+ ACL_CHECK(aclDestroyTensor(acl_src1));
+ ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyIntArray(kernel_size));
+ ACL_CHECK(aclDestroyIntArray(dilations));
+ ACL_CHECK(aclDestroyIntArray(paddings));
+ ACL_CHECK(aclDestroyIntArray(strides));
+}
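+
+// Reference sketch, for illustration only: the standard convolution
+// output-size formula behind the OH/OW extents used above (dst's shape is
+// computed this way by ggml before this kernel runs).
+static int64_t conv_output_size_reference(int64_t ins, int64_t ks, int s,
+                                          int p, int d) {
+    // `ins`: input size, `ks`: kernel size, `s`: stride, `p`: padding,
+    // `d`: dilation along one spatial axis.
+    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+}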
+
+/**
+ * @brief Applies element-wise exponential function to the elements of a tensor.
+ *
+ * This function computes the exponential of each element in the source tensor
+ * `acl_src` and stores the result back into the same tensor.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_src }_i=e^{acl\_src_i}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The tensor on which the exponential function will be applied.
+ */
+static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(
+ aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Multiplies elements of a tensor by a scalar value, optionally
+ * in-place.
+ *
+ * This function multiplies each element of the source tensor `acl_src` by the
+ * scalar `scale` and stores the result in the destination tensor `acl_dst`. If
+ * `inplace` is true, `acl_dst` will not be used and the operation is performed
+ * in-place on `acl_src`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be multiplied.
+ * @param scale The scalar value by which each element of `acl_src` will be
+ * multiplied.
+ * @param acl_dst The destination tensor where the result will be stored if
+ * `inplace` is false.
+ * @param inplace Flag indicating whether to perform the operation in-place on
+ * `acl_src`.
+ */
+static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ float scale, aclTensor* acl_dst, bool inplace) {
+ aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ if (inplace) {
+ ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
+ ctx.stream()));
+ } else {
+ ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
+ }
+
+ ACL_CHECK(aclDestroyScalar(acl_scale));
+}
+
+/**
+ * @brief Performs an in-place element-wise multiplication of two tensors.
+ *
+ * This function performs an element-wise multiplication of the tensors
+ * `acl_src` and `acl_other` and stores the result in `acl_src`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor where the multiplication result will be
+ * stored.
+ * @param acl_other The tensor whose elements will be multiplied with `acl_src`.
+ */
+static void aclnn_inplace_mul(ggml_backend_cann_context& ctx,
+ aclTensor* acl_src, aclTensor* acl_other) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs element-wise multiplication of two tensors and stores the
+ * result in a destination tensor.
+ *
+ * This function performs element-wise multiplication of the tensors `acl_src`
+ * and `acl_other` and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The first tensor for element-wise multiplication.
+ * @param acl_other The second tensor for element-wise multiplication.
+ * @param acl_dst The destination tensor where the result will be stored.
+ */
+static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_other, aclTensor* acl_dst) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Applies element-wise cosine function to the elements of a tensor.
+ *
+ * This function computes the cosine of each element in the source tensor
+ * `acl_src` and stores the result in the destination tensor `acl_dst`. The
+ * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src
+ * }_i\right) \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the cosine function will be
+ * applied.
+ * @param acl_dst The destination tensor where the cosine results will be
+ * stored.
+ */
+static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(
+ aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Applies element-wise sine function to the elements of a tensor.
+ *
+ * This function computes the sine of each element in the source tensor
+ * `acl_src` and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right)
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the sine function will be applied.
+ * @param acl_dst The destination tensor where the sine results will be stored.
+ */
+static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(
+ aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ const ggml_tensor* src = dst->src[0];
+
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+ int half = dim / 2;
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+
+ // arange: [0, ..., half)
+ float start = 0;
+ float stop = half;
+ float step = 1;
+ int64_t n_elements_arange = half;
+ int64_t tmp_arange_ne[] = {half};
+ size_t tmp_arange_nb[] = {sizeof(dst->type)};
+
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(), half * sizeof(dst->type));
+ void* tmp_arange_buffer = arange_allocator.get();
+ aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_arange_ne, tmp_arange_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_arange(ctx, tmp_arange_tensor, start, stop, step, n_elements_arange);
+
+ // freq
+ float freq_param = -logf(max_period) / half;
+ bool inplace = true;
+ aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace);
+ aclnn_exp(ctx, tmp_arange_tensor);
+
+ // permute: src [0,1,2,3]->[0,1,3,2]
+ int64_t tmp_permute_ne[] = {src->ne[1], src->ne[0], src->ne[2], src->ne[3]};
+ size_t tmp_permute_nb[GGML_MAX_DIMS];
+ tmp_permute_nb[0] = ggml_type_size(src->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
+ }
+
+ ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
+ void* tmp_permute_buffer = permute_allocator.get();
+ aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor(
+ tmp_permute_buffer, ggml_cann_type_mapping(src->type),
+ ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ int64_t permute_dim[] = {0, 1, 3, 2};
+ int64_t num_dims = 4;
+ aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims);
+
+ // timestep * freq
+ int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
+ src->ne[3]};
+ size_t tmp_mul_nb[GGML_MAX_DIMS];
+ tmp_mul_nb[0] = ggml_type_size(src->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_mul_nb[i] = tmp_mul_nb[i - 1] * tmp_mul_ne[i - 1];
+ }
+
+ int mul_nelements =
+ src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3];
+
+ ggml_cann_pool_alloc mul_allocator(
+ ctx.pool(), mul_nelements * ggml_type_size(src->type));
+ void* tmp_mul_buffer = mul_allocator.get();
+ aclTensor* tmp_mul_tensor = ggml_cann_create_tensor(
+ tmp_mul_buffer, ggml_cann_type_mapping(src->type),
+ ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+ aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor);
+
+ // cos
+ ggml_cann_pool_alloc cos_allocator(
+ ctx.pool(), mul_nelements * ggml_type_size(src->type));
+ void* tmp_cos_buffer = cos_allocator.get();
+ aclTensor* tmp_cos_tensor = ggml_cann_create_tensor(
+ tmp_cos_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+
+ aclnn_cos(ctx, tmp_mul_tensor, tmp_cos_tensor);
+
+ // sin
+ ggml_cann_pool_alloc sin_allocator(
+ ctx.pool(), mul_nelements * ggml_type_size(src->type));
+ void* tmp_sin_buffer = sin_allocator.get();
+ aclTensor* tmp_sin_tensor = ggml_cann_create_tensor(
+ tmp_sin_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+
+ aclnn_sin(ctx, tmp_mul_tensor, tmp_sin_tensor);
+
+ // concat
+ int64_t concat_dim = 3;
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+ aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
+ aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
+ aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+
+ // release
+    // Destroying both the tensor list and its elements causes a segmentation
+    // fault, so only the list itself is destroyed here.
+ ACL_CHECK(aclDestroyTensorList(tensorList));
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
+ ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
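+
+// Reference sketch, for illustration only: the per-timestep values the
+// arange/exp/mul/cos/sin/concat sequence above produces for one timestep t,
+// assuming an even `dim` and the usual ggml definition of
+// GGML_OP_TIMESTEP_EMBEDDING (cosines in the first half, sines in the second).
+static void timestep_embedding_reference(float t, int dim, int max_period,
+                                         float* out /* at least dim floats */) {
+    const int half = dim / 2;
+    for (int j = 0; j < half; j++) {
+        const float freq = expf(-logf((float)max_period) * j / half);
+        const float arg  = t * freq;
+        out[j]        = cosf(arg);
+        out[j + half] = sinf(arg);
+    }
+}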
+
+/**
+ * @brief Fills a tensor with a scalar value.
+ *
+ * This function fills the destination tensor `acl_dst` with the scalar value
+ * `scalar`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param scalar The scalar value used to fill the tensor.
+ * @param acl_dst The destination tensor to be filled with the scalar value.
+ */
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+ aclTensor* acl_dst) {
+ auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize(
+ acl_dst, acl_scalar, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor,
+ ctx.stream()));
+ ACL_CHECK(aclDestroyScalar(acl_scalar));
+}
+
+/**
+ * @brief Raises each element of a tensor to the power of the corresponding
+ * element in another tensor.
+ *
+ * This function computes the element-wise power of the destination tensor
+ * `acl_dst` raised to the power of the exponent tensor `acl_exp`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst }_i=acl\_dst_i^{\text {acl_exp }_i}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_dst The destination tensor, which also serves as the base tensor.
+ * @param acl_exp The exponent tensor, each element of which is used to raise
+ * the corresponding element in the destination tensor.
+ */
+static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
+ aclTensor* acl_dst, aclTensor* acl_exp) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize(
+ acl_dst, acl_exp, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize,
+ executor, ctx.stream()));
+}
+
+/**
+ * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
+ * attention scores.
+ *
+ * @details This function implements the Alibi mechanism, which introduces
+ * learnable biases into the attention scores to simulate relative
+ * position encoding without the need for explicit positional
+ * embeddings.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param acl_src The source tensor representing the query or key.
+ * @param acl_position The position tensor containing relative positions.
+ * @param acl_dst The destination tensor where the result will be stored.
+ * @param n_head The number of attention heads.
+ * @param src_ne The dimensions of the source tensor.
+ * @param src_nb0 The byte size of the first dimension of the source tensor.
+ * @param max_bias The maximum bias value used in the Alibi mechanism.
+ * @param dst The destination tensor object for additional metadata.
+ *
+ * The function performs the following steps:
+ * 1. Calculates the logarithm floor of the number of heads to determine the
+ *    base for bias calculation.
+ * 2. Initializes arrays with arithmetic sequences and fills them with bias
+ *    values.
+ * 3. Computes the bias tensor based on the calculated biases and arithmetic
+ *    sequences.
+ * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
+ * 5. Multiplies the position tensor by the bias tensor.
+ * 6. Adds the result of the multiplication to the source tensor to produce
+ *    the final output.
+ */
+static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_position, aclTensor* acl_dst,
+ const int n_head, int64_t* src_ne, const size_t src_nb0,
+ float max_bias, ggml_tensor* dst) {
+ const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
+ GGML_ASSERT(src_nb0 == sizeof(float));
+ GGML_ASSERT(n_head == src_ne[2]);
+
+ const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+
+ float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+ // init arange
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+ ne2_ne3 * ggml_type_size(dst->type));
+ void* tmp_arange_buffer = arange_allocator.get();
+
+ // arange1: [1, ..., n_heads_log2_floor+1)
+ float start = 1;
+ float stop = n_heads_log2_floor + 1;
+ float step = 1;
+ int64_t n_elements_arange = n_heads_log2_floor;
+
+ int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+ size_t tmp_arange1_nb[] = {sizeof(dst->type)};
+ aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+ aclTensor* tmp_arange2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+ // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+ start = 1;
+ stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+ step = 2;
+ n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+ int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_arange2_nb[] = {sizeof(dst->type)};
+
+ aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_arange_buffer +
+ n_heads_log2_floor * ggml_type_size(dst->type),
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+ n_elements_arange);
+ }
+
+ // init mk_base
+ ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+ ne2_ne3 * ggml_type_size(dst->type));
+ void* tmp_mk_base_buffer = mk_base_allocator.get();
+ int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+ size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
+ aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+ aclTensor* tmp_mk_base2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+ int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
+ aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_mk_base_buffer +
+ n_heads_log2_floor * ggml_type_size(dst->type),
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+ tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+ }
+
+ // init mk
+ int64_t tmp_mk_base_ne[] = {ne2_ne3};
+ size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
+ aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+ // reshape mk
+ int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
+ size_t tmp_mk_nb[GGML_MAX_DIMS];
+ tmp_mk_nb[0] = ggml_type_size(dst->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+ }
+ aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+
+ // acl_position * mk
+ int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
+ size_t tmp_output_nb[GGML_MAX_DIMS];
+ tmp_output_nb[0] = ggml_type_size(dst->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
+ }
+ ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
+ void* tmp_output_buffer = output_allocator.get();
+ aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
+ tmp_output_buffer, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+ aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
+
+ // add
+ aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
+
+ ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+}
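+
+// Reference sketch, for illustration only: the per-head ALiBi slope that the
+// arange/fill/pow sequence above materializes (the bias added to the scores
+// is slope(h) * position). This mirrors the slope formula used by the ggml
+// CPU soft_max implementation; the helper name is hypothetical.
+static float alibi_slope_reference(int h, int n_head, float max_bias) {
+    const int n_head_log2 = 1 << (int)floorf(log2f((float)n_head));
+    const float m0 = powf(2.0f, -max_bias / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    return h < n_head_log2 ? powf(m0, (float)(h + 1))
+                           : powf(m1, (float)(2 * (h - n_head_log2) + 1));
+}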
+
+void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_cann_dup(ctx, dst);
+}
+
+/**
+ * @brief Performs element-wise addition of two tensors in place.
+ *
+ * This function adds the source tensor `acl_src` to the destination tensor
+ * `acl_dst` element-wise and stores the result in the destination tensor
+ * `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be added.
+ * @param acl_dst The destination tensor which will hold the result of the
+ * addition.
+ */
+static void aclnn_inplace_add(ggml_backend_cann_context& ctx,
+ aclTensor* acl_src, aclTensor* acl_dst) {
+ aclScalar* alpha = nullptr;
+ float alphaValue = 1.0f;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyScalar(alpha));
+}
+
+/**
+ * @brief Applies the softmax function to a tensor along a specified dimension.
+ *
+ * This function computes the softmax of the source tensor `acl_src` along the
+ * specified dimension `dim` and stores the result in the destination tensor
+ * `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the softmax function will be
+ * applied.
+ * @param dim The dimension along which the softmax function will be computed.
+ * @param acl_dst The destination tensor where the softmax results will be
+ * stored.
+ */
+static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ int64_t dim, aclTensor* acl_dst) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ aclrtStream stream = ctx.stream();
+ ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
+}
+
+void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1]; // mask
+
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
+
+ // input mul scale
+ aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+
+ size_t n_bytes = ggml_nbytes(src0);
+ ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
+ void* input_mul_scale_buffer = mul_scale_allocator.get();
+ aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
+ input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
+ src0->nb, GGML_MAX_DIMS);
+
+ bool inplace = false;
+ aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
+
+ // mask
+ aclTensor* acl_src1_fp32_tensor = nullptr;
+ aclTensor* tmp_mask_tensor = nullptr;
+ ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
+ if (src1) {
+ const bool use_f16 = src1->type == GGML_TYPE_F16;
+ if (use_f16) {
+ // cast to fp32
+ size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
+ size_t src1_fp32_nb[GGML_MAX_DIMS];
+ src1_fp32_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
+ }
+ src1_fp32_allocator.alloc(n_bytes);
+ void* src1_fp32_buffer = src1_fp32_allocator.get();
+ acl_src1_fp32_tensor = ggml_cann_create_tensor(
+ src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
+ src1_fp32_nb, GGML_MAX_DIMS);
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+ aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
+
+ ACL_CHECK(aclDestroyTensor(acl_src1));
+ } else {
+ acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
+ }
+
+        // broadcast the mask across rows, using only ne11 of the ne01 rows
+        // of the mask
+ if (src1->ne[1] != src0->ne[1]) {
+ // mask shape: [1,1,ne11,ne10]
+ int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
+ size_t tmp_mask_nb[GGML_MAX_DIMS];
+ tmp_mask_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
+ }
+ tmp_mask_tensor = ggml_cann_create_tensor(
+ src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ }
+
+ // alibi
+ const int n_head = src0->ne[2];
+ const size_t src_nb0 = src0->nb[0];
+
+ n_bytes = ggml_nbytes(dst);
+ ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
+ void* output_buffer = output_allocator.get();
+ aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
+ output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
+ dst->nb, GGML_MAX_DIMS);
+ if (max_bias <= 0.0f) {
+ // slope = 1.0
+ if (tmp_mask_tensor) {
+ aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
+ alibi_output_tensor);
+ } else {
+ aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
+ alibi_output_tensor);
+ }
+ } else {
+ // slope != 1.0
+ if (tmp_mask_tensor) {
+ aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
+ alibi_output_tensor, n_head, src0->ne, src_nb0,
+ max_bias, dst);
+ } else {
+ aclnn_alibi(ctx, acl_input_mul_scale_tensor,
+ acl_src1_fp32_tensor, alibi_output_tensor, n_head,
+ src0->ne, src_nb0, max_bias, dst);
+ }
+ }
+
+ // softmax
+ aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
+ ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
+ } else {
+ aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
+ }
+
+ ACL_CHECK(aclDestroyTensor(acl_src0));
+ ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+ ACL_CHECK(aclDestroyScalar(acl_scale));
+ ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
+ ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
+}
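+
+// Reference sketch, for illustration only: what the fused scale/mask/softmax
+// above computes for a single row of src0, assuming an optional fp32 mask row
+// and a per-head ALiBi slope (1.0 when max_bias == 0).
+static void soft_max_row_reference(const float* x, const float* mask, float* y,
+                                   int64_t n, float scale, float slope) {
+    float max_val = -INFINITY;
+    for (int64_t i = 0; i < n; i++) {
+        y[i] = x[i] * scale + (mask ? slope * mask[i] : 0.0f);
+        max_val = fmaxf(max_val, y[i]);
+    }
+    float sum = 0.0f;
+    for (int64_t i = 0; i < n; i++) {
+        y[i] = expf(y[i] - max_val);
+        sum += y[i];
+    }
+    for (int64_t i = 0; i < n; i++) {
+        y[i] /= sum;
+    }
+}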
+
+void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1];
+
+ ggml_cann_pool_alloc src0_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+ ggml_cann_pool_alloc src1_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+ ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+ src0->extra = src0_extra_allocator.get();
+ src1->extra = src1_extra_allocator.get();
+ dst->extra = dst_extra_allocator.get();
+ ACL_CHECK(aclrtMemcpyAsync(src0->extra, sizeof(ggml_tensor), src0,
+ sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+ ctx.stream()));
+ ACL_CHECK(aclrtMemcpyAsync(src1->extra, sizeof(ggml_tensor), src1,
+ sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+ ctx.stream()));
+ ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
+ sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+ ctx.stream()));
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ aclrtlaunch_ascendc_get_row_f32(
+ 24, ctx.stream(), src0->data, src1->data, dst->data,
+ ((ggml_tensor*)src0->extra)->ne,
+ ((ggml_tensor*)src0->extra)->nb,
+ ((ggml_tensor*)src1->extra)->ne,
+ ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ break;
+ case GGML_TYPE_F16:
+ aclrtlaunch_ascendc_get_row_f16(
+ 24, ctx.stream(), src0->data, src1->data, dst->data,
+ ((ggml_tensor*)src0->extra)->ne,
+ ((ggml_tensor*)src0->extra)->nb,
+ ((ggml_tensor*)src1->extra)->ne,
+ ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ break;
+ case GGML_TYPE_Q4_0:
+ aclrtlaunch_ascendc_get_row_q4_0(
+ 24, ctx.stream(), src0->data, src1->data, dst->data,
+ ((ggml_tensor*)src0->extra)->ne,
+ ((ggml_tensor*)src1->extra)->ne,
+ ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ break;
+ case GGML_TYPE_Q8_0:
+ aclrtlaunch_ascendc_get_row_q8_0(
+ 24, ctx.stream(), src0->data, src1->data, dst->data,
+ ((ggml_tensor*)src0->extra)->ne,
+ ((ggml_tensor*)src1->extra)->ne,
+ ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+ ((ggml_tensor*)dst->extra)->nb);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
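+
+// Reference sketch, for illustration only: the gather the AscendC get_row
+// kernels above perform for the f32 case, ignoring the higher batch
+// dimensions (src1 holds int32 row indices into src0).
+static void get_rows_f32_reference(const float* src0, const int32_t* rows,
+                                   float* dst, int64_t n_rows,
+                                   int64_t row_len) {
+    for (int64_t r = 0; r < n_rows; r++) {
+        const float* src_row = src0 + (int64_t)rows[r] * row_len;
+        for (int64_t c = 0; c < row_len; c++) {
+            dst[r * row_len + c] = src_row[c];
+        }
+    }
+}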
+
+/**
+ * @brief Repeats elements of a tensor along a specified dimension.
+ *
+ * This function repeats each element of the source tensor `acl_src` a specified
+ * number of times (`repeats`) along the specified dimension `dim` and stores
+ * the result in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be repeated.
+ * @param acl_dst The destination tensor where the repeated elements will be
+ * stored.
+ * @param dim The dimension along which the elements will be repeated.
+ * @param repeats The number of times each element will be repeated.
+ * @param output_size The size of the output tensor.
+ */
+static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
+ aclTensor* acl_src, aclTensor* acl_dst,
+ int64_t dim, int64_t repeats,
+ int64_t output_size) {
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize(
+ acl_src, repeats, dim, output_size, acl_dst, &workspaceSize,
+ &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize,
+ executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication of two tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ * \text {acl_dst}=\text {acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+ aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 1;  // ALLOW_FP32_DOWN_PRECISION: when the input is
+                                // fp32, Atlas A2 will convert it to HFLOAT32.
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+ cube_math_type, &workspaceSize,
+ &executor));
+
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(
+ aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication with floating-point precision on
+ * tensors using the CANN backend.
+ *
+ * This function performs matrix multiplication of the input tensor and the
+ * weight tensor, handling broadcasting and transposing as needed, and stores
+ * the result in the destination tensor `dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ ggml_tensor* weight = dst->src[0]; // weight
+ ggml_tensor* input = dst->src[1]; // input
+
+    // When the weight's ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize broadcasts
+    // automatically; otherwise the weight needs to be repeated.
+ BCAST_MUL_MAT_SHAPE(input, weight, dst);
+
+ // transpose weight: [1,2,3,4] -> [1,2,4,3]
+ int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
+ bcast_weight_ne[2], bcast_weight_ne[3],
+ bcast_weight_ne[4], bcast_weight_ne[5]};
+ size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
+ bcast_weight_nb[2], bcast_weight_nb[3],
+ bcast_weight_nb[4], bcast_weight_nb[5]};
+
+ aclTensor* acl_weight_tensor =
+ ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
+ aclTensor* acl_input_tensor =
+ ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
+ aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+
+ ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
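+
+// Reference sketch, for illustration only: swapping ne[0]/ne[1] together with
+// nb[0]/nb[1], as transpose_ne/transpose_nb do above, yields a transposed view
+// of the same buffer without moving any data. Element (i0, i1) of the view is
+// element (i1, i0) of the original tensor; the helper name is hypothetical.
+static float transposed_view_at_reference(const void* data, const size_t nb[2],
+                                          int64_t i0, int64_t i1) {
+    // nb[0]/nb[1] are the original byte strides; the view simply uses them in
+    // swapped order.
+    return *(const float*)((const char*)data + i0 * nb[1] + i1 * nb[0]);
+}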
+
+/**
+ * @brief Performs matrix multiplication with quantized weights and
+ * floating-point inputs using the CANN backend.
+ *
+ * This function performs matrix multiplication of the input tensor `src1` and
+ * the weight tensor `src0`, handling broadcasting, transposing, and
+ * quantization as needed, and stores the result in the destination tensor
+ * `dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0]; // weight
+ ggml_tensor* src1 = dst->src[1]; // input
+
+    // The shape of the weight is NCHW; matrix multiplication uses the HW dims,
+    // and NC is treated as the batch. The weight needs to be transposed.
+ int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
+ size_t weight_elem_size = sizeof(uint8_t);
+ size_t weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+ // size of one matrix is element_size * height * width.
+ size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
+ size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
+
+    // The scales are stored at the end of the weight data and also need to be
+    // transposed.
+ int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
+ size_t scale_elem_size = sizeof(uint16_t);
+ size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
+ scale_elem_size};
+ size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
+ char* scale_offset = (char*)src0->data + weight_size;
+
+ // input
+ void* input_buffer;
+ size_t input_elem_size = sizeof(uint16_t);
+ int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
+ size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
+ size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
+
+ if (src1->type != GGML_TYPE_F16) {
+ aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
+ ggml_cann_pool_alloc input_alloctor(
+ ctx.pool(), ggml_nelements(src1) * input_elem_size);
+ input_buffer = input_alloctor.get();
+
+ int64_t* input_cast_ne = src1->ne;
+ size_t input_cast_nb[GGML_MAX_DIMS];
+ input_cast_nb[0] = sizeof(uint16_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
+ }
+
+ aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+ input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
+ input_cast_nb, GGML_MAX_DIMS);
+ aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
+ } else {
+ input_buffer = src1->data;
+ }
+
+ // output
+ size_t output_elem_size = sizeof(uint16_t);
+ int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
+ size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
+ ggml_cann_pool_alloc output_alloctor(
+ ctx.pool(), ggml_nelements(dst) * output_elem_size);
+ void* output_buffer = output_alloctor.get();
+ size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
+
+ // aclnn
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
+ for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
+ int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
+ int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
+
+ int64_t batch1 = n1 * src1->ne[2] + c1;
+ int64_t batch0 = n0 * src0->ne[2] + c0;
+
+ aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+ (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
+ input_elem_size, input_ne, input_nb, 2);
+ aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
+ (char*)src0->data + batch0 * weight_stride, ACL_INT8,
+ weight_elem_size, weight_ne, weight_nb, 2);
+ aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
+ scale_offset + batch0 * scale_stride, ACL_FLOAT16,
+ scale_elem_size, scale_ne, scale_nb, 2);
+ aclTensor* acl_output_tensor = ggml_cann_create_tensor(
+ (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
+ output_elem_size, output_ne, output_nb, 2);
+
+ ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
+ acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
+ nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
+ &workspaceSize, &executor));
+
+ if (workspaceSize > 0 && workspaceAddr == nullptr) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
+ workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
+ workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+ }
+ }
+
+ // cast out
+ int64_t* output_cast_ne = dst->ne;
+ size_t output_cast_nb[GGML_MAX_DIMS];
+ output_cast_nb[0] = sizeof(uint16_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+ }
+
+ aclTensor* acl_output_tensor =
+ ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
+ output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+ aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
+
+ ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+}
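+
+// Reference sketch, for illustration only: how the buffer layout assumed above
+// maps back to float weights. The int8 quants occupy the first `weight_size`
+// bytes of src0->data and the fp16 scales (one per QK8_0 = 32 values) follow,
+// which is why `scale_offset` starts right after them. The helper name is
+// hypothetical.
+static float q8_0_dequant_reference(const int8_t* row_quants,
+                                    const ggml_fp16_t* row_scales, int64_t i) {
+    // `row_quants` points at one row of ne00 quantized values, `row_scales`
+    // at the ne00/QK8_0 scales belonging to that row.
+    return ggml_fp16_to_fp32(row_scales[i / QK8_0]) * (float)row_quants[i];
+}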
+
+void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ const enum ggml_type type = dst->src[0]->type;
+ switch (type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ ggml_cann_mat_mul_fp(ctx, dst);
+ break;
+ // case GGML_TYPE_Q4_0:
+ // ggml_cann_mul_mat_q4_0(ctx, dst);
+ // break;
+ case GGML_TYPE_Q8_0:
+ ggml_cann_mul_mat_q8_0(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
+/**
+ * @brief Rolls the elements of a tensor along a specified dimension.
+ *
+ * This function rolls the elements of the source tensor `acl_src` by the
+ * specified shifts `shifts` along the specified dimensions `dims`, and stores
+ * the result in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be rolled.
+ * @param acl_dst The destination tensor where the rolled elements will be
+ * stored.
+ * @param shifts An array specifying the number of positions by which elements
+ * are shifted.
+ * @param dims An array specifying the dimensions along which elements are
+ * shifted.
+ */
+static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
+ aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
+ aclIntArray* acl_dims = aclCreateIntArray(dims, 1);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst,
+ &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(acl_shifts));
+ ACL_CHECK(aclDestroyIntArray(acl_dims));
+}
+
+/**
+ * @brief Fills specified positions of a tensor with a scalar value.
+ *
+ * This function fills the positions in the source tensor `acl_src` specified by
+ * `index` along the dimension `dim` with the scalar value `value`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor where the positions will be filled.
+ * @param dim The dimension along which the positions are specified.
+ * @param index An array specifying the positions to be filled.
+ * @param index_num The number of positions specified in the index array.
+ * @param value The scalar value used to fill the specified positions.
+ */
+static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
+ aclTensor* acl_src, int64_t dim,
+ int64_t* index, int64_t index_num,
+ float value) {
+ aclIntArray* acl_index = aclCreateIntArray(index, index_num);
+ aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize(
+ acl_src, dim, acl_index, acl_value, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize,
+ executor, ctx.stream()));
+
+ ACL_CHECK(aclDestroyIntArray(acl_index));
+ ACL_CHECK(aclDestroyScalar(acl_value));
+}
+
+static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+ aclTensor* acl_cos_repeat_tensor,
+ aclTensor* acl_sin_repeat_tensor,
+ float theta_scale, bool is_neox) {
+    // Init the sin/cos cache. The cache uses a different repeat method
+    // depending on @param is_neox.
+
+ ggml_tensor* src0 = dst->src[0]; // input
+ ggml_tensor* src1 = dst->src[1]; // position
+
+ // arange, [0,1,...,ne0/2]
+ int64_t arange_length = src0->ne[0] / 2;
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+ arange_length * sizeof(float_t));
+ void* arange_buffer = arange_allocator.get();
+ int64_t arange_ne[] = {arange_length, 1, 1, 1};
+ size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
+ arange_length * sizeof(float_t)};
+
+ aclTensor* acl_arange_tensor =
+ ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t),
+ arange_ne, arange_nb, GGML_MAX_DIMS);
+ float start = 0;
+ float step = 1;
+ float stop = src0->ne[0] / 2;
+ float n_elements = src0->ne[0] / 2;
+ aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements);
+
+    // power
+    // aclnnPowScalarTensor() expects @param self to be a scalar, so use
+    // aclnn_pow_tensor_tensor() until that is fixed:
+    //   aclScalar* acl_theta_scale =
+    //       aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
+    //   aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor,
+    //                             acl_power_tensor);
+ ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
+ arange_length * sizeof(float_t));
+ void* theta_scale_buffer = theta_scale_allocator.get();
+ aclTensor* acl_theta_scale_tensor = aclnn_ones(
+ ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne,
+ GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale);
+ aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor);
+
+ // position
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ int64_t position_length = src1->ne[0];
+ int64_t position_ne[] = {1, position_length, 1, 1};
+ size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t),
+ sizeof(int32_t) * position_length,
+ sizeof(int32_t) * position_length};
+ aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+ src1->data, ggml_cann_type_mapping(src1->type),
+ ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+
+ // power * position
+ int64_t theta_length = arange_length * position_length;
+ ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* theta_buffer = theta_allocator.get();
+ int64_t theta_ne[] = {arange_length, position_length, 1, 1};
+ size_t theta_nb[GGML_MAX_DIMS];
+ theta_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
+ }
+ aclTensor* acl_theta_tensor =
+ ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
+ theta_ne, theta_nb, GGML_MAX_DIMS);
+ aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+ acl_theta_tensor);
+
+ // permute: [0,1,2,3]->[0,2,1,3]
+ int64_t permute_ne[] = {arange_length, 1, position_length, 1};
+ size_t permute_nb[GGML_MAX_DIMS];
+ permute_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1];
+ }
+ ggml_cann_pool_alloc permute_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* permute_buffer = permute_allocator.get();
+ aclTensor* acl_permute_tensor = ggml_cann_create_tensor(
+ permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ int64_t permute_dim[] = {0, 2, 1, 3};
+ int64_t num_dims = 4;
+ aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim,
+ num_dims);
+
+ // sin/cos
+ ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* sin_buffer = sin_allocator.get();
+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
+ sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor);
+
+ ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+ theta_length * sizeof(float_t));
+ void* cos_buffer = cos_allocator.get();
+ aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
+ cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+ GGML_MAX_DIMS, ACL_FORMAT_ND);
+ aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);
+
+ // repeat
+ if (is_neox) {
+ int64_t repeatsArray[] = {1, 1, 1, 2};
+ aclnn_repeat(ctx, acl_sin_tensor, acl_sin_repeat_tensor, repeatsArray);
+ aclnn_repeat(ctx, acl_cos_tensor, acl_cos_repeat_tensor, repeatsArray);
+ } else {
+ int64_t num_repeats = 2;
+ int64_t dim = 3;
+ int64_t output_size = arange_length * num_repeats;
+ aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim,
+ num_repeats, output_size);
+ aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim,
+ num_repeats, output_size);
+ }
+
+ // release
+ ACL_CHECK(aclDestroyTensor(acl_arange_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_position_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_permute_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
+}
+
+void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ // TODO: use ascendc
+ // Only test with LLAMA model.
+ ggml_tensor* src0 = dst->src[0]; // input
+ ggml_tensor* src2 = dst->src[2]; // freq_factors
+
+ // TODO: with freq_factors
+ GGML_ASSERT(src2 == NULL);
+
+ // param
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ // const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t*)dst->op_params)[1];
+ const int mode = ((int32_t*)dst->op_params)[2];
+ // const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t*)dst->op_params)[4];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ memcpy(&freq_base, (int32_t*)dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t*)dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t*)dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t*)dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float));
+
+ GGML_ASSERT(n_dims <= ne0);
+ GGML_ASSERT(n_dims % 2 == 0);
+
+ // TODO: ext_factor != 0
+ GGML_ASSERT(ext_factor == 0);
+ // TODO: freq_scale != 1
+ GGML_ASSERT(freq_scale == 1);
+
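+    // theta_scale is the per-dimension frequency ratio used below:
+    // theta_i = freq_base^(-2*i/n_dims) = theta_scale^i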
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
+ beta_slow, corr_dims);
+
+ const bool is_neox = mode & 2;
+
+ // init cos/sin cache
+ ggml_cann_pool_alloc sin_allocator(
+ ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
+ ggml_cann_pool_alloc cos_allocator(
+ ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
+ void* sin_buffer = sin_allocator.get();
+ void* cos_buffer = cos_allocator.get();
+
+ int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
+ size_t sin_reshape_nb[GGML_MAX_DIMS];
+ sin_reshape_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+ }
+ aclTensor* acl_sin_reshape_tensor =
+ ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+ aclTensor* acl_cos_reshape_tensor =
+ ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+ aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
+ theta_scale, is_neox);
+
+ // roll input
+ void* input_roll_buffer;
+ aclTensor* acl_minus_one_tensor;
+ void* minus_one_scale_buffer = nullptr;
+ ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
+ ggml_cann_pool_alloc minus_one_scale_allocator(
+ ctx.pool(), sizeof(float_t) * src0->ne[0]);
+ if (!is_neox) {
+ // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
+ input_roll_buffer = roll_allocator.get();
+ int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
+ src0->ne[2], src0->ne[3]};
+ size_t input_roll_nb[GGML_MAX_DIMS];
+ input_roll_nb[0] = ggml_type_size(src0->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
+ }
+ aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+ input_roll_buffer, ggml_cann_type_mapping(src0->type),
+ ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+ GGML_MAX_DIMS);
+ aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+ src0->data, ggml_cann_type_mapping(src0->type),
+ ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+ GGML_MAX_DIMS);
+
+ int64_t shifts[] = {1};
+ int64_t dims[] = {3};
+ aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+
+ // init [-1, 1, -1, 1, ...]
+ minus_one_scale_buffer = minus_one_scale_allocator.get();
+
+ int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+ size_t minus_one_nb[GGML_MAX_DIMS];
+ minus_one_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+ }
+ acl_minus_one_tensor = aclnn_ones(
+ ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+ int64_t dim = 3;
+ int64_t* index = new int64_t[src0->ne[0]];
+ for (int i = 0; i < src0->ne[0]; i++) {
+ index[i] = i / 2 * 2;
+ }
+ int64_t index_num = src0->ne[0];
+ float value = -1;
+ aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
+ index_num, value);
+ } else {
+ // roll input: [q0,q1,q2,...] ->
+ // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
+ input_roll_buffer = roll_allocator.get();
+ aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+ input_roll_buffer, ggml_cann_type_mapping(src0->type),
+ ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
+ aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
+
+ int64_t shifts[] = {src0->ne[0] / 2};
+ int64_t dims[] = {3};
+ aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+
+ // init [-1, -1, -1, 1, 1,1,...]
+ minus_one_scale_buffer = minus_one_scale_allocator.get();
+
+ int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+ size_t minus_one_nb[GGML_MAX_DIMS];
+ minus_one_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+ }
+ acl_minus_one_tensor = aclnn_ones(
+ ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+ // -1 * first half
+ int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
+ size_t first_half_nb[GGML_MAX_DIMS];
+ first_half_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
+ }
+ aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
+ minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
+ first_half_nb, GGML_MAX_DIMS);
+ bool inplace = true;
+ float scale = -1;
+ aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
+ ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+ }
+
+ // TODO: n_dims < ne0
+ GGML_ASSERT(n_dims == src0->ne[0]);
+
+ // input * scale
+ ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
+ ggml_nbytes(src0));
+ void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
+ size_t input_nb[GGML_MAX_DIMS];
+ input_nb[0] = ggml_type_size(src0->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
+ }
+ aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
+ input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
+ ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+ aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
+ input_roll_buffer, ggml_cann_type_mapping(src0->type),
+ ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+
+ aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
+ acl_input_roll_mul_scale_tensor);
+
+ // output
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+ void* output_fp32_buffer;
+ if (src0->type == GGML_TYPE_F32) {
+ aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
+ aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
+ acl_sin_reshape_tensor);
+ aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
+ // TODO: ne0 != n_dims in mode2
+ } else if (src0->type == GGML_TYPE_F16) {
+ size_t input_fp32_nb[GGML_MAX_DIMS];
+ input_fp32_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
+ }
+ ggml_cann_pool_alloc fp32_allocator1(
+ ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ void* input_fp32_buffer1 = fp32_allocator1.get();
+ aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
+ input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
+ input_fp32_nb, GGML_MAX_DIMS);
+ ggml_cann_pool_alloc fp32_allocator2(
+ ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ void* input_fp32_buffer2 = fp32_allocator2.get();
+ aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
+ input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
+ input_fp32_nb, GGML_MAX_DIMS);
+
+ ggml_cann_pool_alloc fp32_allocator(
+ ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+ output_fp32_buffer = fp32_allocator.get();
+ aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
+ output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
+ input_fp32_nb, GGML_MAX_DIMS);
+ aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
+ aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
+ input_fp32_tensor2);
+ aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
+ output_fp32_tensor);
+ aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
+
+ ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
+ ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
+ ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
+ }
+
+ ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
+ ACL_CHECK(aclDestroyTensor(acl_src0));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
new file mode 100644
index 00000000..680129c7
--- /dev/null
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -0,0 +1,592 @@
+#ifndef CANN_ACLNN_OPS
+#define CANN_ACLNN_OPS
+
+/**
+ * @file aclnn_ops
+ * @brief This file declares the operator implementations of the CANN backend
+ * built on the aclnn operator library, together with templated helpers for
+ * binary and activation operators.
+ * @author hipudding <huafengchun@gmail.com>
+ * @author wangshuai09 <391746016@qq.com>
+ * @date July 15, 2024
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_arange.h>
+#include <aclnnop/aclnn_argsort.h>
+#include <aclnnop/aclnn_cat.h>
+#include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_div.h>
+#include <aclnnop/aclnn_gelu.h>
+#include <aclnnop/aclnn_hardsigmoid.h>
+#include <aclnnop/aclnn_hardswish.h>
+#include <aclnnop/aclnn_leaky_relu.h>
+#include <aclnnop/aclnn_mul.h>
+#include <aclnnop/aclnn_relu.h>
+#include <aclnnop/aclnn_silu.h>
+#include <aclnnop/aclnn_tanh.h>
+#include "acl_tensor.h"
+#include "common.h"
+
+/**
+ * @brief Repeats a ggml tensor along each dimension to match the dimensions
+ * of another tensor.
+ *
+ * @details This function repeats the elements of a source ggml tensor along
+ * each dimension to create a destination tensor with the specified
+ * dimensions. The operation is performed using the ACL backend and
+ * executed asynchronously on the device.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The ggml tensor representing the destination, whose op is
+ *            GGML_OP_REPEAT and which specifies the desired dimensions.
+ */
+void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Adds two ggml tensors using the CANN backend.
+ *
+ * @details This function performs an element-wise addition of two tensors. In
+ * case the tensors do not have the same shape, one or both tensors
+ * will be broadcasted to match the shape of the other before the
+ *          addition is performed. The formula for the operation is given by:
+ * \f[
+ * \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1}
+ * \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The ggml tensor representing the destination; the result of the
+ *            addition is stored at dst->data, and dst->op is `GGML_OP_ADD`.
+ */
+void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies the Leaky ReLU activation function to a tensor using the CANN
+ * backend.
+ *
+ * @details This function computes the Leaky ReLU activation for each element of
+ * the input tensor. The Leaky ReLU function allows a small gradient
+ * when the unit is not active (i.e., when the input is negative). The
+ * Leaky ReLU function is defined as:
+ * \f[
+ * \text{dst} = \max(0, src) + \text{negativeSlope} \cdot \min(0,
+ * src)
+ * \f]
+ * `negativeSlope` is in dst->params.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result of the Leaky ReLU
+ *            activation is stored; its op is `GGML_OP_LEAKY_RELU`.
+ */
+void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Concatenates multiple tensors along a specified dimension using the
+ * CANN backend.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param tensorList A pointer to the list of tensors to be concatenated.
+ * @param dst The destination tensor where the result of the
+ * concatenation is stored. dst->op is `GGML_OP_CONCAT`.
+ * @param concat_dim The dimension along which the tensors are concatenated.
+ *
+ * @attention tensorList length should be 2 and the dimension used for
+ *            concatenation defaults to 1.
+ */
+void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Generates a sequence of evenly spaced values within a specified
+ * interval for a ggml tensor using the CANN backend.
+ *
+ * @details This function creates a sequence of numbers over a specified
+ *          interval, starting from `start`, ending before `stop`, and
+ * incrementing by `step`. The sequence is stored in the destination
+ * tensor `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the generated sequence will be stored.
+ *        `start`, `stop` and `step` are in dst->op_params and dst->op is
+ * `GGML_OP_ARANGE`.
+ */
+void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the square of the elements of a ggml tensor using the CANN
+ * backend.
+ * @details The function sets the second source tensor of the destination
+ * tensor `dst` to be equal to the first source tensor. This is
+ * effectively squaring the elements since the multiplication becomes
+ * `element * element`.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the squared values will be stored;
+ *            dst->op is `GGML_OP_SQR`.
+ */
+void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies a clamp operation to the elements of a ggml tensor using the
+ * CANN backend.
+ *
+ * @details This function clamps the elements of the input tensor `src` to a
+ * specified range defined by `min` and `max` values. The result is
+ * stored in the destination tensor `dst`. The operation is defined as:
+ * \f[
+ * y = \max(\min(x, max\_value), min\_value)
+ * \f]
+ * where `x` is an element of the input tensor, and `y` is the
+ * corresponding element in the output tensor.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the clamped values will be stored.
+ *            dst->op is `GGML_OP_CLAMP`; `min` and `max` values are in dst->params.
+ */
+void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Scales the elements of a ggml tensor by a constant factor using the
+ * CANN backend.
+ *
+ * @details This function multiplies each element of the input tensor `src` by
+ * a scaling factor `scale`, storing the result in the destination
+ * tensor `dst`. The operation is defined as:
+ * \f[
+ * dst = src \times scale
+ * \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the scaled values will be stored.
+ * dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
+ */
+void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Sorts the elements of a ggml tensor and returns the indices that
+ * would sort the tensor using the CANN backend.
+ *
+ * @details This function performs an argsort operation on the input tensor
+ * `src`. It sorts the elements of `src` in either ascending or
+ *          descending order, depending on whether `GGML_SORT_ORDER_DESC`
+ *          is requested,
+ * and returns the indices that would sort the original tensor.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the sorted indices will be stored.
+ * dst->op is `GGML_OP_ARGSORT`.
+ */
+void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the Layer Normalization for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function applies the Layer Normalization operation on the
+ * input tensor `src` and stores the result in the destination tensor
+ * `dst`. Layer Normalization normalizes the features at each sample in
+ * a mini-batch independently. It is commonly used in neural networks
+ * to normalize the activations of a layer by adjusting and scaling
+ * the outputs.
+ * The operation is defined as:
+ * \f[
+ * \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
+ * \f]
+ *          `Var` defaults to dst->ne[0]. `eps` is in dst->params.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ * @attention `Var` defaults to dst->ne[0].
+ */
+void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the Group Normalization for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function applies the Group Normalization operation on the input
+ * tensor `src` and stores the result in the destination tensor `dst`.
+ * Group Normalization divides the channels into groups and normalizes
+ * the features within each group across spatial locations.
+ * It is commonly used in convolutional neural networks to improve
+ * training stability and performance.
+ * The operation is defined as:
+ * \f[
+ * \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
+ * \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ *            `n_groups` is in dst->params; it splits the channel dimension
+ *            into `n_groups` groups.
+ * dst->op is `GGML_OP_GROUP_NORM`.
+ *
+ * @attention eps defaults to 1e-6f.
+ */
+void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the accumulation of tensors using the CANN backend.
+ *
+ * @details This function performs an accumulation operation on two tensors.
+ * Depending on the `inplace` flag, it either updates the destination
+ * tensor `dst` in place by adding `alpha * src1` to it, or it creates
+ * a new tensor as the result of `src0 + alpha * src1` and stores it in
+ * `dst`.
+ * The operation is defined as:
+ * \f[
+ * dst = src0 + alpha \times src1
+ * \f]
+ *          If `inplace` is `true`, `src0` is equal to `dst`.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the accumulated values will be stored.
+ * `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`.
+ */
+void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the sum of elements along the last dimension of a ggml tensor
+ * using the CANN backend.
+ *
+ * @details This function performs a reduction sum operation along the last
+ * dimension of the input tensor `src`. The result of the sum is stored
+ * in the destination tensor `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the reduced values will be stored.
+ * dst->op is `GGML_OP_SUM_ROWS`.
+ *
+ * @attention `reduce_dims` defaults to 3, which means the last dimension.
+ */
+void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Upsamples a ggml tensor using nearest neighbor interpolation using
+ * the CANN backend.
+ *
+ * @details This function performs upsampling of the input tensor `src` using
+ * nearest neighbor interpolation. The upsampling is applied to the
+ * height and width dimensions (last two dimensions) of the tensor. The
+ * result is stored in the destination tensor `dst`, which must have
+ * the appropriate dimensions for the upsampled output.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the upsampled values will be stored.
+ * dst->op is `GGML_OP_UPSCALE`.
+ */
+void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
+ ggml_tensor* dst);
+
+/**
+ * @brief Pads a ggml tensor to match the dimensions of the destination tensor
+ * using the CANN backend.
+ *
+ * @details This function pads the input tensor `src` so that it matches the
+ * dimensions of the destination tensor `dst`. The amount of padding
+ * is calculated based on the difference in sizes between `src` and
+ * `dst` along each dimension. The padded tensor is stored in `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor, which specifies the target dimensions for
+ * padding. dst->op is `GGML_OP_PAD`.
+ */
+void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Executes a 2D pooling operation on a ggml tensor using the CANN
+ * backend.
+ *
+ * @details This function dispatches the execution of a 2D pooling operation on
+ * the input tensor `dst`. The type of pooling (average or max) is
+ * determined by the `op` parameter, which is read from the operation
+ * parameters of `dst`. The function supports average pooling
+ * (`GGML_OP_POOL_AVG`) and max pooling (`GGML_OP_POOL_MAX`). If an
+ * invalid operation is encountered, the function asserts a failure.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor on which the pooling operation is to be
+ * performed. dst->op is `GGML_OP_POOL_2D`.
+ */
+void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Duplicates a ggml tensor using the CANN backend.
+ *
+ * @details This function duplicates the contents of the source tensor `src` to
+ * the destination tensor `dst`. The function supports various tensor
+ * types and configurations, including handling of extra data, type
+ * conversions, and special cases for contiguous and non-contiguous
+ * tensors.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the duplicated data will be stored.
+ * dst->op is `GGML_OP_DUP`
+ *
+ * @attention Only FP16/FP32 are supported. Not supported when src and dst
+ *            have different shapes and dst is non-contiguous.
+ * @note This function needs to be simplified.
+ */
+void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor
+ * using the CANN backend.
+ *
+ * @details This function applies RMS normalization to the input tensor `src`
+ * and stores the result in the destination tensor `dst`. RMS
+ * normalization involves computing the root mean square of the input
+ * tensor along a specified dimension and then dividing each element of
+ * the tensor by this value, adjusted by a small epsilon value to
+ * prevent division by zero.
+ * The operation is defined as:
+ * \f[
+ * \text{RmsNorm}\left(x_i\right)=\frac{x_i}{\text{Rms}(\mathbf{x})} g_i,
+ * \quad \text { where } \text{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+e p s}
+ * \f]
+ * `eps` is in dst->op_params.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ * dst->op is `GGML_OP_RMS_NORM`.
+ */
+void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies a diagonal mask to the tensor with a specified value.
+ *
+ * @details This function creates a mask tensor filled with ones, then applies
+ * an upper triangular and lower triangular operation to it based on
+ * the number of past elements specified. Afterward, it adds the masked
+ * tensor to the destination tensor in-place.
+ *
+ * @param ctx The backend CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored. dst->op is
+ * `GGML_OP_DIAG_MASK`
+ * @param value The value to use for masking.
+ */
+void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
+
+/**
+ * @brief Performs an image-to-column transformation on the input tensor.
+ *
+ * @details This function takes an input tensor and applies an image-to-column
+ * operation, converting spatial dimensions into column-like
+ * structures suitable for convolutional operations. It supports both
+ * half-precision (F16) and single-precision (F32) floating-point data
+ * types.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor that stores the result of the operation.
+ * dst->op is `GGML_OP_IM2COL`.
+ */
+void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes time step embeddings using sine and cosine functions.
+ *
+ * @details This function calculates time step embeddings by applying sine and
+ * cosine transformations to a given input tensor, which is typically
+ * used in temporal models like diffusion models or transformers to
+ * encode time information effectively.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the result of the embedding operation
+ * will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
+ */
+void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+// @see ggml_cann_dup.
+void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Computes the softmax activation with optional masking.
+ *
+ * @details This function computes the softmax activation over the input tensor,
+ * optionally applying a mask and scaling factor. It supports both FP16
+ * and FP32 data types and can handle masking by broadcasting the mask
+ * across rows if necessary.
+ * The function performs the following steps:
+ * 1. Multiplies the input tensor by a scale factor.
+ * 2. Optionally casts the mask tensor to FP32 if it is in FP16 format.
+ * 3. Broadcasts the mask tensor if its dimensions do not match the
+ * input tensor's dimensions.
+ * 4. Adds the mask to the scaled input tensor.
+ * 5. Applies the softmax activation function along the specified
+ * dimension.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the result will be stored. dst->op is
+ * `GGML_OP_SOFTMAX`.
+ */
+void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Extracts specific rows from a tensor based on indices.
+ *
+ * @details This function retrieves rows from a source tensor src0 according to
+ * the indices provided in another tensor src1 and stores the result in
+ * a destination tensor (\p dst). It supports different data types
+ * including F32, F16, Q4_0, and Q8_0.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the extracted rows will be stored.
+ * dst->op is `GGML_OP_GET_ROWS`.
+ */
+void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Executes matrix multiplication for the given tensor.
+ *
+ * @details This function performs matrix multiplication on the source tensors
+ * associated with the destination tensor. It supports matrix
+ * multiplication F32, F16, and Q8_0.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor for storing the result of the matrix
+ * multiplication. dst->op is `GGML_OP_MUL_MAT`.
+ */
+void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
+ *
+ * @details This function implements the RoPE mechanism, which is a method to
+ * encode positional information into sequence data, particularly
+ * useful in transformer models. It supports both F32 and F16 data
+ * types.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the RoPE-transformed data will be
+ * stored. dst->op is `GGML_OP_ROPE`.
+ *
+ * @note The function currently does not support cases where the n_dims is less
+ * than the input tensor's first dimension.
+ * @note The function currently does not support cases where the freq_factors is
+ * not NULL.
+ * @note The function currently does not support cases where the ext_factor is
+ *       not equal to 0.
+ * @note The function currently does not support cases where the freq_scale is
+ *       not equal to 1.
+ */
+void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
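+// Note: in the implementation, bit 1 of `mode` (i.e. mode & 2) selects the
+// NEOX rotation layout (see the `is_neox` flag in ggml_cann_rope).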
+
+template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
+ aclTensor*, uint64_t*, aclOpExecutor**),
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
+void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0];
+ ggml_tensor* src1 = dst->src[1];
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+ aclTensor* acl_src0;
+ aclTensor* acl_src1;
+ aclTensor* acl_dst;
+
+ // Need bcast
+ if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
+ BCAST_SHAPE(src0, src1)
+ acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+ acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+ acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+ } else {
+ acl_src0 = ggml_cann_create_tensor(src0);
+ acl_src1 = ggml_cann_create_tensor(src1);
+ acl_dst = ggml_cann_create_tensor(dst);
+ }
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize,
+ &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ aclrtStream main_stream = ctx.stream();
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+ ACL_CHECK(aclDestroyTensor(acl_src0));
+ ACL_CHECK(aclDestroyTensor(acl_src1));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
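+
+// Illustrative instantiation (a sketch, not the dispatch code itself): assuming
+// the aclnn binary-op entry points declared in <aclnnop/aclnn_mul.h> and
+// <aclnnop/aclnn_div.h>, element-wise multiply/divide can be dispatched as
+//   ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+//   ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);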
+
+// Activation functions template.
+template <aclnnStatus getWorkspaceSize(const aclTensor*, aclTensor*, uint64_t*,
+ aclOpExecutor**),
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
+ const aclrtStream)>
+void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ aclrtStream main_stream = ctx.stream();
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+// Activation functions template for const aclTensors.
+template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
+ uint64_t*, aclOpExecutor**),
+ aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
+ const aclrtStream)>
+void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src = dst->src[0];
+
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ uint64_t workspaceSize = 0;
+ aclOpExecutor* executor;
+ void* workspaceAddr = nullptr;
+
+ ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+ if (workspaceSize > 0) {
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+ workspaceAddr = workspace_allocator.get();
+ }
+
+ aclrtStream main_stream = ctx.stream();
+ ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+ ACL_CHECK(aclDestroyTensor(acl_src));
+ ACL_CHECK(aclDestroyTensor(acl_dst));
+}
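+
+// Illustrative instantiation (a sketch): assuming the unary entry points from
+// <aclnnop/aclnn_relu.h> and <aclnnop/aclnn_silu.h>, activations can be
+// dispatched as
+//   ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(ctx, dst);
+//   ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(ctx, dst);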
+
+#endif // CANN_ACLNN_OPS
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
new file mode 100644
index 00000000..e6a57010
--- /dev/null
+++ b/ggml/src/ggml-cann/common.h
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#ifndef CANN_COMMON_H
+#define CANN_COMMON_H
+
+#include <acl/acl.h>
+
+#include <cstdio>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../include/ggml-cann.h"
+#include "../include/ggml.h"
+
+#define MATRIX_ROW_PADDING 512
+#define GGML_CANN_MAX_STREAMS 8
+
+/**
+ * @brief Handles CANN-related errors by printing an error message and
+ * terminating the program.
+ * @param stmt The statement that caused the error.
+ * @param func The function in which the error occurred.
+ * @param file The file in which the error occurred.
+ * @param line The line number at which the error occurred.
+ * @param msg The error message.
+ */
+[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
+ const char* file, int line, const char* msg);
+
+/**
+ * @brief Checks the result of a CANN function call and invokes the error
+ * handler if the call fails.
+ * @param stmt The CANN function call to check.
+ * @param success The success code that indicates the call was successful.
+ * @param error_fn The function to call to retrieve the error message.
+ */
+#define ACL_CHECK_GEN(stmt, success, error_fn) \
+ do { \
+ int err_code = (stmt); \
+ if (err_code != (success)) { \
+ ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
+ } \
+ } while (0);
+
+#define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
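+
+// Example: ACL_CHECK(aclDestroyTensor(acl_dst)); on a non-zero return code the
+// macro reports the failing statement together with the message from
+// aclGetRecentErrMsg() and terminates via ggml_cann_error().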
+
+/**
+ * @brief Contains information about CANN devices.
+ */
+struct ggml_cann_device_info {
+ /**
+ * @brief Number of CANN devices available.
+ */
+ int32_t device_count;
+
+ /**
+ * @brief Information about a single CANN device.
+ */
+ struct cann_device_info {
+ int cc; /**< Compute capability. */
+ size_t smpb; /**< Maximum shared memory per block. */
+ bool vmm; /**< Virtual memory support. */
+ size_t vmm_granularity; /**< Granularity of virtual memory. */
+ size_t total_vram; /**< Total video RAM available on the device. */
+ };
+
+ cann_device_info devices[GGML_CANN_MAX_DEVICES] =
+ {}; /**< Array of CANN device information. */
+};
+
+const ggml_cann_device_info& ggml_cann_info();
+
+void ggml_cann_set_device(int32_t device);
+int32_t ggml_cann_get_device();
+
+/**
+ * @brief Abstract base class for memory pools used by CANN.
+ */
+struct ggml_cann_pool {
+ /**
+ * @brief Virtual destructor for the memory pool.
+ */
+ virtual ~ggml_cann_pool() = default;
+
+ /**
+ * @brief Allocates memory from the pool.
+ *
+ * @param size The size of the memory block to allocate.
+ * @param actual_size Pointer to a variable where the actual allocated size
+ * will be stored.
+ * @return Pointer to the allocated memory block.
+ */
+ virtual void* alloc(size_t size, size_t* actual_size) = 0;
+
+ /**
+ * @brief Frees a previously allocated memory block.
+ *
+ * @param ptr Pointer to the memory block to free.
+ * @param size Size of the memory block to free.
+     * @note Note that all CANN operators run asynchronously. Make sure the
+     *       memory is still available until the operator has finished.
+ */
+ virtual void free(void* ptr, size_t size) = 0;
+};
+
+/**
+ * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
+ */
+struct ggml_cann_pool_alloc {
+ ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
+ void* ptr = nullptr; /**< Pointer to the allocated memory block. */
+ size_t actual_size = 0; /**< Actual size of the allocated memory block. */
+
+ /**
+ * @brief Default constructor.
+ */
+ ggml_cann_pool_alloc() = default;
+
+ /**
+ * @brief Constructor that initializes the memory pool.
+ * @param pool Reference to the memory pool.
+ */
+ explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
+
+ /**
+ * @brief Constructor that initializes the memory pool and allocates memory.
+ * @param pool Reference to the memory pool.
+ * @param size Size of the memory block to allocate.
+ */
+ ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
+ alloc(size);
+ }
+
+ /**
+ * @brief Destructor that frees the allocated memory block.
+ */
+ ~ggml_cann_pool_alloc() {
+ if (ptr != nullptr) {
+ pool->free(ptr, actual_size);
+ }
+ }
+
+ /**
+ * @brief Allocates memory from the pool.
+ * @param size Size of the memory block to allocate.
+ * @return Pointer to the allocated memory block.
+ */
+ void* alloc(size_t size) {
+ GGML_ASSERT(pool != nullptr);
+ GGML_ASSERT(ptr == nullptr);
+ ptr = pool->alloc(size, &this->actual_size);
+ return ptr;
+ }
+
+ /**
+ * @brief Allocates memory from a specific memory pool.
+ * @param pool Reference to the memory pool.
+ * @param size Size of the memory block to allocate.
+ * @return Pointer to the allocated memory block.
+ */
+ void* alloc(ggml_cann_pool& pool, size_t size) {
+ this->pool = &pool;
+ return alloc(size);
+ }
+
+ /**
+ * @brief Gets the pointer to the allocated memory block.
+ * @return Pointer to the allocated memory block.
+ */
+ void* get() { return ptr; }
+
+ // Deleted copy constructor
+ ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
+
+ // Deleted move constructor
+ ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
+
+ // Deleted copy assignment operator
+ ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
+
+ // Deleted move assignment operator
+ ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
+};
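+
+// Usage sketch (illustrative, assuming a ggml_backend_cann_context `ctx` and an
+// element count `n`): scratch memory is taken from the context pool and freed
+// automatically when the allocator goes out of scope.
+//   ggml_cann_pool_alloc tmp(ctx.pool(), n * sizeof(float));
+//   void* buf = tmp.get();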
+
+/**
+ * @brief Context for managing CANN backend operations.
+ */
+struct ggml_backend_cann_context {
+ int32_t device; /**< Device ID. */
+ std::string name; /**< Name of the device. */
+ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+
+ aclrtStream streams[GGML_CANN_MAX_STREAMS] = {
+ {nullptr}}; /**< Array of streams for the device. */
+
+ /**
+ * @brief Constructor for initializing the context with a given device.
+ * @param device Device ID.
+ */
+ explicit ggml_backend_cann_context(int device)
+ : device(device), name("CANN" + std::to_string(device)) {}
+
+ /**
+ * @brief Destructor for cleaning up resources.
+ */
+ ~ggml_backend_cann_context() {
+ if (copy_event != nullptr) {
+ ACL_CHECK(aclrtDestroyEvent(copy_event));
+ }
+ for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
+ if (streams[i] != nullptr) {
+ ACL_CHECK(aclrtDestroyStream(streams[i]));
+ }
+ }
+ }
+
+ /**
+ * @brief Get or create a stream for a given index.
+ * @param stream Index of the stream.
+ * @return The stream corresponding to the given index.
+ */
+ aclrtStream stream(int stream) {
+ if (streams[stream] == nullptr) {
+ ggml_cann_set_device(device);
+ ACL_CHECK(aclrtCreateStream(&streams[stream]));
+ }
+ return streams[stream];
+ }
+
+ /**
+ * @brief Get or create the default stream (index 0).
+ * @return The default stream.
+ */
+ aclrtStream stream() { return stream(0); }
+
+ // TODO: each stream should have a memory pool.
+ std::unique_ptr<ggml_cann_pool>
+ mem_pool; /**< Memory pool for the device. */
+
+ /**
+ * @brief Create a new memory pool for a given device.
+ * @param device Device ID.
+ * @return A unique pointer to the new memory pool.
+ */
+ static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
+
+ /**
+ * @brief Get or create the memory pool for the context.
+ * @return Reference to the memory pool.
+ */
+ ggml_cann_pool& pool() {
+ if (mem_pool == nullptr) {
+ mem_pool = new_pool_for_device(device);
+ }
+ return *mem_pool;
+ }
+};
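+
+// Usage sketch (illustrative): operators obtain their execution stream and
+// scratch pool from the context, e.g.
+//   aclrtStream st = ctx.stream();   // default stream (index 0)
+//   ggml_cann_pool_alloc ws(ctx.pool(), workspace_size);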
+
+#endif // CANN_COMMON_H
diff --git a/ggml/src/ggml-cann/kernels/CMakeLists.txt b/ggml/src/ggml-cann/kernels/CMakeLists.txt
new file mode 100644
index 00000000..f12a4d43
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/CMakeLists.txt
@@ -0,0 +1,32 @@
+if (NOT SOC_TYPE)
+ set (SOC_TYPE "Ascend910B3")
+endif()
+
+file(GLOB SRC_FILES
+ get_row_f32.cpp
+ get_row_f16.cpp
+ get_row_q4_0.cpp
+ get_row_q8_0.cpp
+ quantize_f32_q8_0.cpp
+ quantize_f16_q8_0.cpp
+ dup.cpp
+)
+
+string(TOLOWER ${SOC_TYPE} SOC_VERSION)
+set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
+set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
+
+if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+ set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+ set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+ message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.")
+endif()
+include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
+
+ascendc_library(ascendc_kernels STATIC
+ ${SRC_FILES}
+)
+
+#ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP)
diff --git a/ggml/src/ggml-cann/kernels/ascendc_kernels.h b/ggml/src/ggml-cann/kernels/ascendc_kernels.h
new file mode 100644
index 00000000..bf891475
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/ascendc_kernels.h
@@ -0,0 +1,17 @@
+#ifndef ASCENDC_KERNELS_H
+#define ASCENDC_KERNELS_H
+
+#include "aclrtlaunch_ascendc_get_row_f32.h"
+#include "aclrtlaunch_ascendc_get_row_f16.h"
+#include "aclrtlaunch_ascendc_get_row_q8_0.h"
+#include "aclrtlaunch_ascendc_get_row_q4_0.h"
+
+#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
+#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
+
+#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h"
+
+#endif // ASCENDC_KERNELS_H
diff --git a/ggml/src/ggml-cann/kernels/dup.cpp b/ggml/src/ggml-cann/kernels/dup.cpp
new file mode 100644
index 00000000..e2c65115
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/dup.cpp
@@ -0,0 +1,223 @@
+#include "kernel_operator.h"
+
+#include <cmath>
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+template <typename SRC_T, typename DST_T>
+class DupByRows {
+ public:
+ __aicore__ inline DupByRows() {}
+ __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub,
+ size_t *input_nb_ub) {
+        /* Dup by rows when src is contiguous in its first dimension and dst is
+           contiguous; each kernel processes one row.
+ */
+
+ // Input has four dims.
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ // param
+ num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3];
+ num_elem = input_ne_ub[0];
+
+ // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3)
+ idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]);
+ idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]))
+ / (input_ne_ub[1]);
+ idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])
+ - idx_ne2 * input_ne_ub[1];
+
+        // src may not be contiguous in dims [1,2,3], so the stride is decided by ne & nb
+ src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2
+ + input_nb_ub[1] * idx_ne1;
+
+ // dst is contiguous
+ dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T));
+
+ src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src +
+ src_stride));
+ dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst +
+ dst_stride));
+
+ pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem +
+ 32 - 1) / 32 * 32);
+ pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem +
+ 32 - 1) / 32 * 32);
+ }
+
+ __aicore__ inline void copy_in() {
+ LocalTensor<SRC_T> src_local = src_queue.AllocTensor<SRC_T>();
+
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = num_elem * sizeof(SRC_T);
+ DataCopyPadExtParams<SRC_T> padParams;
+ DataCopyPad(src_local, src_gm, dataCopyParams, padParams);
+
+ src_queue.EnQue(src_local);
+ }
+
+ __aicore__ inline void copy_out() {
+ LocalTensor<DST_T> dst_local = dst_queue.DeQue<DST_T>();
+
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = num_elem * sizeof(DST_T);
+ DataCopyPad(dst_gm, dst_local, dataCopyParams);
+
+ dst_queue.FreeTensor(dst_local);
+ }
+
+ __aicore__ inline void dup() {
+ // main process, copy one row data from src to dst.
+ copy_in();
+
+ LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
+ LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
+
+ int32_t BLOCK_NUM = 32 / sizeof(DST_T);
+ DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1)
+ / BLOCK_NUM * BLOCK_NUM);
+ dst_queue.EnQue<DST_T>(dst_local);
+
+ src_queue.FreeTensor(src_local);
+ copy_out();
+ }
+
+ __aicore__ inline void dup_with_cast() {
+ // main process, copy one row data from src to dst.
+ // cast dtype from src to dst.
+ copy_in();
+
+ LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
+ LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
+
+ Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem);
+ dst_queue.EnQue<DST_T>(dst_local);
+
+ src_queue.FreeTensor(src_local);
+ copy_out();
+ }
+
+ private:
+
+ TPipe pipe;
+ GlobalTensor<SRC_T> src_gm;
+ GlobalTensor<DST_T> dst_gm;
+
+ int64_t num_rows;
+ int64_t num_elem;
+ int64_t idx_ne3;
+ int64_t idx_ne2;
+ int64_t idx_ne1;
+ int64_t src_stride;
+ int64_t dst_stride;
+
+ TQue<QuePosition::VECIN, BUFFER_NUM> src_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> dst_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
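+
+// The ne/nb parameter blocks copied below are 4 entries of 8 bytes each, which
+// is why a fixed 32-byte copy is used for every block.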
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16(
+ GM_ADDR src_gm,
+ GM_ADDR dst_gm,
+ GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm,
+ GM_ADDR output_ne_gm,
+ GM_ADDR output_nb_gm) {
+
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ DupByRows<half, half> op;
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+ op.dup();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32(
+ GM_ADDR src_gm,
+ GM_ADDR dst_gm,
+ GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm,
+ GM_ADDR output_ne_gm,
+ GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ DupByRows<float_t, float_t> op;
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+ op.dup();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16(
+ GM_ADDR src_gm,
+ GM_ADDR dst_gm,
+ GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm,
+ GM_ADDR output_ne_gm,
+ GM_ADDR output_nb_gm) {
+
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ DupByRows<float_t, half> op;
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+ op.dup_with_cast();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32(
+ GM_ADDR src_gm,
+ GM_ADDR dst_gm,
+ GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm,
+ GM_ADDR output_ne_gm,
+ GM_ADDR output_nb_gm) {
+
+ // copy params from gm to ub.
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ DupByRows<half, float_t> op;
+ op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+ op.dup_with_cast();
+}
diff --git a/ggml/src/ggml-cann/kernels/get_row_f16.cpp b/ggml/src/ggml-cann/kernels/get_row_f16.cpp
new file mode 100644
index 00000000..c704b5b2
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/get_row_f16.cpp
@@ -0,0 +1,186 @@
+#include "kernel_operator.h"
+
+// optimize me: use a template to avoid duplicated code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+class GET_ROW_F16 {
+ public:
+ __aicore__ inline GET_ROW_F16() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *indices_ne_ub, size_t *indices_nb_ub,
+ int64_t *output_ne_ub, size_t *output_nb_ub) {
+        // TODO: use a template for F16/F32
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ indices_ne[i] = indices_ne_ub[i];
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+ }
+
+        // Indices has two dims. n_elements = total number of rows to fetch.
+        // dr = number of rows this block should fetch.
+ uint64_t n_elements =
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+ dr = n_elements / op_block_num;
+
+ uint64_t tails = n_elements % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
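+        // e.g. n_elements = 10 rows over op_block_num = 4 blocks: blocks 0,1
+        // take 3 rows (ir = 0, 3) and blocks 2,3 take 2 rows (ir = 6, 8).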
+
+ input_gm.SetGlobalBuffer((__gm__ half *)input);
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+ uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31)
+ & ~31);
+ uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31)
+ & ~31);
+
+ local_buffer_elems = input_local_buffer_size / sizeof(half);
+
+        // TODO: consider long rows that cannot fit into the UB.
+        // All buffers are rounded up to 32 bytes; this is fine because the data is 32-byte aligned.
+ pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size);
+ pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size);
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset, size_t len) {
+ LocalTensor<half> input_local = input_queue.AllocTensor<half>();
+ size_t tail = len % 32;
+ len = len & ~31;
+ DataCopy(input_local, input_gm[offset], len);
+ if(tail != 0) {
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = tail * sizeof(half);
+ DataCopyPadExtParams<half> padParams;
+ DataCopyPad(input_local[len], input_gm[offset + len],
+ dataCopyParams, padParams);
+ }
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset, size_t len) {
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
+ size_t tail = len % 32;
+ len = len & ~31;
+ DataCopy(output_gm[offset], output_local, len);
+ if(tail != 0) {
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = tail * sizeof(float);
+ DataCopyPad(output_gm[offset + len], output_local[len],
+ dataCopyParams);
+ }
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void calculate_row(int64_t idx) {
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+ const int64_t indices_ne1_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+ indices_ne[0];
+ const int64_t indices_ne0_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+ indices_ne1_idx * indices_ne[0]);
+
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+ indices_ne1_idx * indices_stride[1] +
+ indices_ne2_idx * indices_stride[2];
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
+ indices_ne1_idx * input_stride[2] +
+ indices_ne2_idx * input_stride[3];
+
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+ indices_ne1_idx * output_stride[2] +
+ indices_ne2_idx * output_stride[3];
+
+ copy_in(input_offset, input_ne[0]);
+ LocalTensor<half> input_local = input_queue.DeQue<half>();
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+ Cast(output_local, input_local, RoundMode::CAST_NONE,
+ local_buffer_elems);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset, input_ne[0]);
+
+ input_queue.FreeTensor(input_local);
+ }
+
+ __aicore__ inline void calculate() {
+ for (int64_t i = ir; i < ir + dr; i++) {
+ calculate_row(i);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t indices_ne[4];
+ size_t indices_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ size_t local_buffer_elems;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<half> input_gm;
+ GlobalTensor<int32_t> indices_gm;
+ GlobalTensor<float> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_f16(
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+ GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
+ GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t indices_ne_ub[4];
+ size_t indices_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ GET_ROW_F16 op;
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
+ indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-cann/kernels/get_row_f32.cpp b/ggml/src/ggml-cann/kernels/get_row_f32.cpp
new file mode 100644
index 00000000..9db080af
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/get_row_f32.cpp
@@ -0,0 +1,180 @@
+#include "kernel_operator.h"
+
+// TODO: optimize by using a template to avoid duplicating this code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+class GET_ROW_F32 {
+ public:
+ __aicore__ inline GET_ROW_F32() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *indices_ne_ub, size_t *indices_nb_ub,
+ int64_t *output_ne_ub, size_t *output_nb_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ indices_ne[i] = indices_ne_ub[i];
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+ }
+
+        // The indices tensor has two dims. n_elements is the total number of rows
+        // to gather; dr is the number of rows handled by this block.
+ uint64_t n_elements =
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+ dr = n_elements / op_block_num;
+
+ uint64_t tails = n_elements % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
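+        // Illustrative example: with n_elements = 10 and op_block_num = 4, blocks
+        // 0-1 take 3 rows each (ir = 0, 3) and blocks 2-3 take 2 rows each (ir = 6, 8).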
+
+ input_gm.SetGlobalBuffer((__gm__ float *)input);
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+ uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31);
+ local_buffer_elems = local_buffer_size / sizeof(float);
+
+        // TODO: handle rows that are too long to fit in the UB.
+        // The local buffer size is rounded up to a multiple of 32; this is fine
+        // because all of the data is 32-aligned.
+ pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size);
+ pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size);
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset, size_t len) {
+ LocalTensor<float> input_local = input_queue.AllocTensor<float>();
+ size_t tail = len % 32;
+ len = len & ~31;
+ DataCopy(input_local, input_gm[offset], len);
+ if(tail != 0) {
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = tail * sizeof(float);
+ DataCopyPadExtParams<float> padParams;
+ DataCopyPad(input_local[len], input_gm[offset + len],
+ dataCopyParams, padParams);
+ }
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset, size_t len) {
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
+ size_t tail = len % 32;
+ len = len & ~31;
+ DataCopy(output_gm[offset], output_local, len);
+ if(tail != 0) {
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = tail * sizeof(float);
+ DataCopyPad(output_gm[offset + len], output_local[len],
+ dataCopyParams);
+ }
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void calculate_row(int64_t idx) {
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+ const int64_t indices_ne1_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+ indices_ne[0];
+ const int64_t indices_ne0_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+ indices_ne1_idx * indices_ne[0]);
+
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+ indices_ne1_idx * indices_stride[1] +
+ indices_ne2_idx * indices_stride[2];
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
+ indices_ne1_idx * input_stride[2] +
+ indices_ne2_idx * input_stride[3];
+
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+ indices_ne1_idx * output_stride[2] +
+ indices_ne2_idx * output_stride[3];
+
+ copy_in(input_offset, input_ne[0]);
+ LocalTensor<float> input_local = input_queue.DeQue<float>();
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+ DataCopy(output_local, input_local, local_buffer_elems);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset, input_ne[0]);
+
+ input_queue.FreeTensor(input_local);
+ }
+
+ __aicore__ inline void calculate() {
+ for (int64_t i = ir; i < ir + dr; i++) {
+ calculate_row(i);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t indices_ne[4];
+ size_t indices_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ size_t local_buffer_elems;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<float> input_gm;
+ GlobalTensor<int32_t> indices_gm;
+ GlobalTensor<float> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_f32(
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+ GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
+ GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t indices_ne_ub[4];
+ size_t indices_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ GET_ROW_F32 op;
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
+ indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp
new file mode 100644
index 00000000..a80bfeec
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp
@@ -0,0 +1,193 @@
+#include "kernel_operator.h"
+
+// TODO: optimize by using a template to avoid duplicating this code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+#define QK4_0 32
+
+class GET_ROW_Q4_0 {
+ public:
+ __aicore__ inline GET_ROW_Q4_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+ int64_t *input_ne_ub, int64_t *indices_ne_ub,
+ size_t *indices_nb_ub, int64_t *output_ne_ub,
+ size_t *output_nb_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ indices_ne[i] = indices_ne_ub[i];
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+ scale_ne[i] = input_ne_ub[i];
+ output_ne[i] = output_ne_ub[i];
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+ }
+
+        // one scale per group.
+ scale_ne[0] /= QK4_0;
+
+ input_stride[0] = 1;
+ scale_stride[0] = 1;
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ group_size_in_row = input_ne[0] / QK4_0;
+ int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
+ input_ne[3] / 2;
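+        // As seen by this kernel, the q4_0 data holds all 4-bit quants first (half a
+        // byte per value, hence the division by 2), followed by one half-precision
+        // scale per group of QK4_0 values.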
+
+        // The indices tensor has two dims. n_elements is the total number of rows
+        // to gather; dr is the number of rows handled by this block.
+ uint64_t n_elements =
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+ dr = n_elements / op_block_num;
+
+ uint64_t tails = n_elements % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ input_gm.SetGlobalBuffer((__gm__ int4b_t *)input);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t));
+ pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<int4b_t> input_local = input_queue.AllocTensor<int4b_t>();
+        // 32 * sizeof(int4b_t) = 16 bytes, which is not 32-byte aligned; it is
+        // unclear why this does not raise an error.
+ DataCopy(input_local, input_gm[offset], QK4_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
+ DataCopy(output_gm[offset], output_local, QK4_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+ const int64_t indices_ne1_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+ indices_ne[0];
+ const int64_t indices_ne0_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+ indices_ne1_idx * indices_ne[0]);
+
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+ indices_ne1_idx * indices_stride[1] +
+ indices_ne2_idx * indices_stride[2];
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
+ indices_ne1_idx * input_stride[2] +
+ indices_ne2_idx * input_stride[3] +
+ group * QK4_0;
+ const int64_t scale_offset = selected_row_idx * scale_stride[1] +
+ indices_ne1_idx * scale_stride[2] +
+ indices_ne2_idx * scale_stride[3] + group;
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+ indices_ne1_idx * output_stride[2] +
+ indices_ne2_idx * output_stride[3] +
+ group * QK4_0;
+
+ copy_in(input_offset);
+ LocalTensor<int4b_t> input_local = input_queue.DeQue<int4b_t>();
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        // TODO: cast more data at a time to speed this up.
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0);
+ Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0);
+
+        // Only the multiplication by the scale has to be done per group.
+ half scale = scale_gm.GetValue(scale_offset);
+
+ Muls(output_local, output_local, (float)scale, QK4_0);
+
+ input_queue.FreeTensor(input_local);
+ cast_queue.FreeTensor(cast_local);
+ output_queue.EnQue(output_local);
+
+ copy_out(output_offset);
+ }
+
+ __aicore__ inline void calculate() {
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ calculate_group(i, j);
+ }
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t scale_ne[4];
+ size_t scale_stride[4];
+
+ int64_t indices_ne[4];
+ size_t indices_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t ir;
+ int64_t dr;
+
+ int64_t group_size_in_row;
+
+ TPipe pipe;
+ GlobalTensor<int4b_t> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int32_t> indices_gm;
+ GlobalTensor<float> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+ GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
+ GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ int64_t indices_ne_ub[4];
+ size_t indices_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ GET_ROW_Q4_0 op;
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
+ indices_nb_ub, output_ne_ub, output_nb_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp
new file mode 100644
index 00000000..ba9ab3c0
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp
@@ -0,0 +1,191 @@
+#include "kernel_operator.h"
+
+// optimize me. Use template to avoid copy code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+#define QK8_0 32
+
+class GET_ROW_Q8_0 {
+ public:
+ __aicore__ inline GET_ROW_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+ int64_t *input_ne_ub, int64_t *indices_ne_ub,
+ size_t *indices_nb_ub, int64_t *output_ne_ub,
+ size_t *output_nb_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ indices_ne[i] = indices_ne_ub[i];
+ indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+ scale_ne[i] = input_ne_ub[i];
+ output_ne[i] = output_ne_ub[i];
+ output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+ }
+
+        // one scale per group.
+ scale_ne[0] /= QK8_0;
+
+ input_stride[0] = 1;
+ scale_stride[0] = 1;
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ group_size_in_row = input_ne[0] / QK8_0;
+ int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
+ input_ne[3] * sizeof(int8_t);
+
+        // The indices tensor has two dims. n_elements is the total number of rows
+        // to gather; dr is the number of rows handled by this block.
+ uint64_t n_elements =
+ indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+ dr = n_elements / op_block_num;
+
+ uint64_t tails = n_elements % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ input_gm.SetGlobalBuffer((__gm__ int8_t *)input);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
+ indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+ output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<int8_t> input_local = input_queue.AllocTensor<int8_t>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<float> output_local = output_queue.DeQue<float>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
+ const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+ const int64_t indices_ne1_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+ indices_ne[0];
+ const int64_t indices_ne0_idx =
+ (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+ indices_ne1_idx * indices_ne[0]);
+
+ const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+ indices_ne1_idx * indices_stride[1] +
+ indices_ne2_idx * indices_stride[2];
+ const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+ const int64_t input_offset = selected_row_idx * input_stride[1] +
+ indices_ne1_idx * input_stride[2] +
+ indices_ne2_idx * input_stride[3] +
+ group * QK8_0;
+ const int64_t scale_offset = selected_row_idx * scale_stride[1] +
+ indices_ne1_idx * scale_stride[2] +
+ indices_ne2_idx * scale_stride[3] + group;
+ const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+ indices_ne1_idx * output_stride[2] +
+ indices_ne2_idx * output_stride[3] +
+ group * QK8_0;
+
+ copy_in(input_offset);
+ LocalTensor<int8_t> input_local = input_queue.DeQue<int8_t>();
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+ LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        // TODO: cast more data at a time to speed this up.
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+ Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0);
+
+        // Only the multiplication by the scale has to be done per group.
+ half scale = scale_gm.GetValue(scale_offset);
+ Muls(output_local, output_local, (float)scale, QK8_0);
+
+ input_queue.FreeTensor(input_local);
+ cast_queue.FreeTensor(cast_local);
+ output_queue.EnQue(output_local);
+
+ copy_out(output_offset);
+ }
+
+ __aicore__ inline void calculate() {
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ calculate_group(i, j);
+ }
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t scale_ne[4];
+ size_t scale_stride[4];
+
+ int64_t indices_ne[4];
+ size_t indices_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t ir;
+ int64_t dr;
+
+ int64_t group_size_in_row;
+
+ TPipe pipe;
+ GlobalTensor<int8_t> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int32_t> indices_gm;
+ GlobalTensor<float> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_q8_0(
+ GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+ GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
+ GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+ int64_t input_ne_ub[4];
+ int64_t indices_ne_ub[4];
+ size_t indices_nb_ub[4];
+ int64_t output_ne_ub[4];
+ size_t output_nb_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+ copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+ copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+ GET_ROW_Q8_0 op;
+ op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
+ indices_nb_ub, output_ne_ub, output_nb_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp
new file mode 100644
index 00000000..8423b3f0
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp
@@ -0,0 +1,208 @@
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+#define QK8_0 32
+
+class QUANTIZE_F16_Q8_0 {
+ public:
+ __aicore__ inline QUANTIZE_F16_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *output_ne_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ }
+
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+ }
+
+ scale_ne = input_ne;
+ scale_stride[0] = 1;
+ scale_stride[1] = input_ne[0] / QK8_0;
+ for (int i = 2; i < 4; i++) {
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ // split input tensor by rows.
+ uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+ dr = nr / op_block_num;
+
+ uint64_t tails = nr % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ group_size_in_row = scale_stride[1];
+ int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+ output_ne[3] * sizeof(uint8_t);
+
+ input_gm.SetGlobalBuffer((__gm__ half *)input);
+ output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir *
+ group_size_in_row *
+ sizeof(half)));
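+        // As addressed here, the group scales are stored right after the int8 quants
+        // (output_size bytes) in the output buffer; this block starts writing its
+        // scales at its own first row (offset ir * group_size_in_row scales).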
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(work_queue, 1, 32);
+ pipe.InitBuffer(max_queue, 1, 32);
+ pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+ pipe.InitBuffer(scale_queue, 1, 32);
+        pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(float));
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<half> input_local = input_queue.AllocTensor<half>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+ const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+ const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+ const int64_t i1 =
+ row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+ const int64_t input_offset = i1 * input_stride[1] +
+ i2 * input_stride[2] +
+ i3 * input_stride[3] + QK8_0 * group;
+
+ const int64_t output_offset = i1 * output_stride[1] +
+ i2 * output_stride[2] +
+ i3 * output_stride[3] + QK8_0 * group;
+
+ copy_in(input_offset);
+ LocalTensor<half> input_local = input_queue.DeQue<half>();
+ LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+ LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+ LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+ LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+ LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
+
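+        // q8_0 quantization of one group of QK8_0 values: d = max(|x|) / 127 and
+        // q_i = round(x_i / d), so the int8 quants lie in [-127, 127]; d is returned
+        // and written out as the group scale in calculate().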
+ Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+ Abs(abs_local, cast_local, QK8_0);
+ ReduceMax(max_local, abs_local, work_local, QK8_0);
+
+ pipe_barrier(PIPE_ALL);
+ float d = max_local.GetValue(0);
+ d = d / ((1 << 7) - 1);
+ if (d != 0) {
+ Muls(cast_local, cast_local, 1.0f / d, QK8_0);
+ }
+
+ Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset);
+
+ input_queue.FreeTensor(input_local);
+ work_queue.FreeTensor(work_local);
+ abs_queue.FreeTensor(abs_local);
+ max_queue.FreeTensor(max_local);
+ cast_queue.FreeTensor(cast_local);
+ return (half)d;
+ }
+
+ __aicore__ inline void calculate() {
+ LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+ uint32_t scale_local_offset = 0;
+ uint32_t scale_global_offset = 0;
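+        // Group scales are buffered in UB and flushed to GM 16 at a time
+        // (16 halves = 32 bytes, one aligned block); any remainder is written with
+        // DataCopyPad after the loop.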
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ half scale = calculate_group(i, j);
+ scale_local.SetValue(scale_local_offset++, scale);
+ if (scale_local_offset == 16) {
+ scale_local_offset = 0;
+ // TODO: OPTIMIZE ME
+ pipe_barrier(PIPE_ALL);
+ DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+ pipe_barrier(PIPE_ALL);
+ scale_global_offset += 16;
+ }
+ }
+ }
+
+ if (scale_local_offset != 0) {
+ pipe_barrier(PIPE_ALL);
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+ DataCopyPad(scale_gm[scale_global_offset], scale_local,
+ dataCopyParams);
+ pipe_barrier(PIPE_ALL);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t *scale_ne;
+ size_t scale_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t group_size_in_row;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<half> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int8_t> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, 1> work_queue;
+ TQue<QuePosition::VECOUT, 1> max_queue;
+ TQue<QuePosition::VECIN, 1> abs_queue;
+ TQue<QuePosition::VECOUT, 1> scale_queue;
+ TQue<QuePosition::VECOUT, 1> cast_queue;
+
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_F16_Q8_0 op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp
new file mode 100644
index 00000000..b7c57509
--- /dev/null
+++ b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp
@@ -0,0 +1,206 @@
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+#define QK8_0 32
+
+class QUANTIZE_F32_Q8_0 {
+ public:
+ __aicore__ inline QUANTIZE_F32_Q8_0() {}
+ __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+ int64_t *input_ne_ub, size_t *input_nb_ub,
+ int64_t *output_ne_ub) {
+ int64_t op_block_num = GetBlockNum();
+ int64_t op_block_idx = GetBlockIdx();
+
+ for (int i = 0; i < 4; i++) {
+ input_ne[i] = input_ne_ub[i];
+ input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+ output_ne[i] = output_ne_ub[i];
+ }
+
+ output_stride[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+ }
+
+ scale_ne = input_ne;
+ scale_stride[0] = 1;
+ scale_stride[1] = input_ne[0] / QK8_0;
+ for (int i = 2; i < 4; i++) {
+ scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+ }
+
+ // split input tensor by rows.
+ uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+ dr = nr / op_block_num;
+
+ uint64_t tails = nr % op_block_num;
+ if (op_block_idx < tails) {
+ dr += 1;
+ ir = dr * op_block_idx;
+ } else {
+ ir = dr * op_block_idx + tails;
+ }
+
+ group_size_in_row = scale_stride[1];
+ int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+ output_ne[3] * sizeof(uint8_t);
+
+ input_gm.SetGlobalBuffer((__gm__ float *)input);
+ output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+ scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size +
+ ir * group_size_in_row *
+ sizeof(half)));
+
+ pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+ pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+ pipe.InitBuffer(work_queue, 1, 32);
+ pipe.InitBuffer(max_queue, 1, 32);
+ pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+ pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half));
+ pipe.InitBuffer(scale_queue, 1, 32);
+ }
+
+ __aicore__ inline void copy_in(uint32_t offset) {
+ LocalTensor<float> input_local = input_queue.AllocTensor<float>();
+ DataCopy(input_local, input_gm[offset], QK8_0);
+ input_queue.EnQue(input_local);
+ }
+
+ __aicore__ inline void copy_out(uint32_t offset) {
+ LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+ DataCopy(output_gm[offset], output_local, QK8_0);
+ output_queue.FreeTensor(output_local);
+ }
+
+ __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+ const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+ const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+ const int64_t i1 =
+ row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+ const int64_t input_offset = i1 * input_stride[1] +
+ i2 * input_stride[2] +
+ i3 * input_stride[3] + QK8_0 * group;
+
+ const int64_t output_offset = i1 * output_stride[1] +
+ i2 * output_stride[2] +
+ i3 * output_stride[3] + QK8_0 * group;
+
+ copy_in(input_offset);
+ LocalTensor<float> input_local = input_queue.DeQue<float>();
+ LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+ LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+ LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+ LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+ LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+
+ Abs(abs_local, input_local, QK8_0);
+ ReduceMax(max_local, abs_local, work_local, QK8_0);
+ pipe_barrier(PIPE_ALL);
+ float d = max_local.GetValue(0);
+ d = d / ((1 << 7) - 1);
+ if (d != 0) {
+ Muls(input_local, input_local, 1.0f / d, QK8_0);
+ }
+
+ Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+ Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+ output_queue.EnQue(output_local);
+ copy_out(output_offset);
+
+ input_queue.FreeTensor(input_local);
+ work_queue.FreeTensor(work_local);
+ abs_queue.FreeTensor(abs_local);
+ max_queue.FreeTensor(max_local);
+ cast_queue.FreeTensor(cast_local);
+
+ return (half)d;
+ }
+
+ __aicore__ inline void calculate() {
+ LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+ uint32_t scale_local_offset = 0;
+ uint32_t scale_global_offset = 0;
+ for (int64_t i = ir; i < ir + dr; i++) {
+ for (int64_t j = 0; j < group_size_in_row; j++) {
+ half scale = calculate_group(i, j);
+ scale_local.SetValue(scale_local_offset++, scale);
+ if (scale_local_offset == 16) {
+ scale_local_offset = 0;
+ // TODO: OPTIMIZE ME
+ pipe_barrier(PIPE_ALL);
+ DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+ pipe_barrier(PIPE_ALL);
+ scale_global_offset += 16;
+ }
+ }
+ }
+
+ if (scale_local_offset != 0) {
+ pipe_barrier(PIPE_ALL);
+ DataCopyExtParams dataCopyParams;
+ dataCopyParams.blockCount = 1;
+ dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+ DataCopyPad(scale_gm[scale_global_offset], scale_local,
+ dataCopyParams);
+ pipe_barrier(PIPE_ALL);
+ }
+ }
+
+ private:
+ int64_t input_ne[4];
+ size_t input_stride[4];
+
+ int64_t *scale_ne;
+ size_t scale_stride[4];
+
+ int64_t output_ne[4];
+ size_t output_stride[4];
+
+ int64_t group_size_in_row;
+
+ int64_t ir;
+ int64_t dr;
+
+ TPipe pipe;
+ GlobalTensor<float> input_gm;
+ GlobalTensor<half> scale_gm;
+ GlobalTensor<int8_t> output_gm;
+ TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+ TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+ TQue<QuePosition::VECIN, 1> work_queue;
+ TQue<QuePosition::VECOUT, 1> max_queue;
+ TQue<QuePosition::VECIN, 1> abs_queue;
+ TQue<QuePosition::VECIN, 1> cast_queue;
+ TQue<QuePosition::VECOUT, 1> scale_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+ auto gm_ptr = (__gm__ uint8_t *)gm;
+ auto ub_ptr = (uint8_t *)(ub);
+ for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+ *ub_ptr = *gm_ptr;
+ }
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
+ GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+ GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+ int64_t input_ne_ub[4];
+ size_t input_nb_ub[4];
+ int64_t output_ne_ub[4];
+
+ copy_to_ub(input_ne_gm, input_ne_ub, 32);
+ copy_to_ub(input_nb_gm, input_nb_ub, 32);
+ copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+ QUANTIZE_F32_Q8_0 op;
+ op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+ op.calculate();
+}
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
new file mode 100644
index 00000000..da3f1b3c
--- /dev/null
+++ b/ggml/src/ggml-common.h
@@ -0,0 +1,1880 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#ifndef GGML_COMMON_DECL
+
+#if defined(GGML_COMMON_DECL_C)
+#include <stdint.h>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+#define GGML_COMMON_AGGR
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_METAL)
+#include <metal_stdlib>
+
+typedef half ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CUDA)
+#include <cuda_fp16.h>
+#include <cstdint>
+
+typedef half ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_HIP)
+#include <hip/hip_fp16.h>
+#include <cstdint>
+
+typedef half ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_SYCL)
+#include <sycl/half_type.hpp>
+#include <cstdint>
+
+typedef sycl::half ggml_half;
+typedef sycl::half2 ggml_half2;
+
+#define GGML_COMMON_AGGR data
+
+#define GGML_COMMON_DECL
+#endif
+
+#if defined(GGML_COMMON_DECL)
+
+#ifndef __cplusplus
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+#endif // __cplusplus
+
+// QK = number of values after dequantization
+// QK_K = super-block size
+
+#define QK_K 256
+#define K_SCALE_SIZE 12
+
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
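+// e.g. for q4_0: QR4_0 = 2 and QI4_0 = QK4_0 / (4 * QR4_0) = 32 / 8 = 4, i.e. the
+// 16 quant bytes of one q4_0 block are processed as four 32-bit integers.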
+
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+#define QR4_0 2
+
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+#define QR4_1 2
+
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+#define QR5_0 2
+
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+#define QR5_1 2
+
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+#define QR8_0 1
+
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+#define QR8_1 1
+
+#define QI2_K (QK_K / (4*QR2_K))
+#define QR2_K 4
+
+#define QI3_K (QK_K / (4*QR3_K))
+#define QR3_K 4
+
+#define QI4_K (QK_K / (4*QR4_K))
+#define QR4_K 2
+
+#define QI5_K (QK_K / (4*QR5_K))
+#define QR5_K 2
+
+#define QI6_K (QK_K / (4*QR6_K))
+#define QR6_K 2
+
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+#define QR2_XXS 4
+
+#define QI2_XS (QK_K / (4*QR2_XS))
+#define QR2_XS 4
+
+#define QI2_S (QK_K / (4*QR2_S))
+#define QR2_S 4
+
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+#define QR3_XXS 4
+
+#define QI3_XS (QK_K / (4*QR3_XS))
+#define QR3_XS 4
+
+#define QI1_S (QK_K / (4*QR1_S))
+#define QR1_S 8
+
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+#define QR4_NL 2
+
+#define QI4_XS (QK_K / (4*QR4_XS))
+#define QR4_XS 2
+
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 4
+
+#define QI1_BN (QK_IQ1BN / (4*QR1_BN))
+#define QR1_BN 8
+
+#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP || GGML_COMMON_DECL_SYCL
+
+#define QK4_0 32
+typedef struct {
+ ggml_half d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+ union {
+ struct {
+ ggml_half d; // delta
+ ggml_half m; // min
+ } GGML_COMMON_AGGR;
+ ggml_half2 dm;
+ };
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+ ggml_half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+ union {
+ struct {
+ ggml_half d; // delta
+ ggml_half m; // min
+ } GGML_COMMON_AGGR;
+ ggml_half2 dm;
+ };
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+ ggml_half d; // delta
+ int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+ union {
+ struct {
+ ggml_half d; // delta
+ ggml_half s; // d * sum(qs[i])
+ } GGML_COMMON_AGGR;
+ ggml_half2 ds;
+ };
+ int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
+
+typedef struct {
+ ggml_half d[8];
+ int8_t qs[4*QK8_1];
+} block_q8_1_x4;
+static_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), "wrong q8_1_x4 block size/padding");
+typedef struct {
+ ggml_half d[4];
+ int8_t qs[4*QK8_0];
+} block_q8_0_x4;
+static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding");
+
+typedef struct {
+ ggml_half d[4]; // deltas for 4 q4_0 blocks
+ uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
+} block_q4_0x4;
+static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
+
+typedef struct {
+ ggml_half d[8]; // deltas for 8 q4_0 blocks
+ uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
+} block_q4_0x8;
+static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
+
+typedef struct {
+ ggml_half d[4]; // deltas for 4 q8_0 blocks
+ int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks
+} block_q8_0x4;
+static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
+
+typedef struct {
+ ggml_half d[8]; // deltas for 8 q8_0 blocks
+ int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks
+} block_q8_0x8;
+static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
+
+//
+// Super-block quantization structures
+//
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.625 bits per weight
+typedef struct {
+ uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+ uint8_t qs[QK_K/4]; // quants
+ union {
+ struct {
+ ggml_half d; // super-block scale for quantized scales
+ ggml_half dmin; // super-block scale for quantized mins
+ } GGML_COMMON_AGGR;
+ ggml_half2 dm;
+ };
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
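+// Size check: QK_K/16 + QK_K/4 + 2*sizeof(ggml_half) = 16 + 64 + 4 = 84 bytes per
+// super-block of 256 weights, i.e. 84*8/256 = 2.625 bits per weight as stated above;
+// the bpw figures quoted for the other *_K types follow from the same arithmetic.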
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+typedef struct {
+ uint8_t hmask[QK_K/8]; // quants - high bit
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
+ uint8_t scales[12]; // scales, quantized with 6 bits
+ ggml_half d; // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+typedef struct {
+ union {
+ struct {
+ ggml_half d; // super-block scale for quantized scales
+ ggml_half dmin; // super-block scale for quantized mins
+ } GGML_COMMON_AGGR;
+ ggml_half2 dm;
+ };
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+typedef struct {
+ union {
+ struct {
+ ggml_half d; // super-block scale for quantized scales
+ ggml_half dmin; // super-block scale for quantized mins
+ } GGML_COMMON_AGGR;
+ ggml_half2 dm;
+ };
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+ uint8_t qh[QK_K/8]; // quants, high bit
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+ ggml_half d; // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+ float d; // delta
+ int8_t qs[QK_K]; // quants
+ int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+typedef struct {
+ float d; // delta
+ int8_t qs[64]; // quants
+} block_q8_K64;
+static_assert(sizeof(block_q8_K64) == sizeof(float) + 64, "wrong q8_K64 block size/padding");
+typedef struct {
+ float d; // delta
+ int8_t qs[128]; // quants
+} block_q8_K128;
+static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding");
+
+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+ ggml_half d;
+ uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
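+// Size check: sizeof(ggml_half) + QK_K/8 * sizeof(uint16_t) = 2 + 64 = 66 bytes per
+// super-block of 256 weights, i.e. 66*8/256 = 2.0625 bits per weight.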
+
+// 2.3125 bpw quants
+typedef struct {
+ ggml_half d;
+ uint16_t qs[QK_K/8];
+ uint8_t scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
+// 2.5625 bpw quants
+typedef struct {
+ ggml_half d;
+ uint8_t qs[QK_K/4];
+ uint8_t qh[QK_K/32];
+ uint8_t scales[QK_K/32];
+} block_iq2_s;
+static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
+
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+ ggml_half d;
+ uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
+// 3.4375 bpw
+#define IQ3S_N_SCALE QK_K/64
+typedef struct {
+ ggml_half d;
+ uint8_t qs[QK_K/4];
+ uint8_t qh[QK_K/32];
+ uint8_t signs[QK_K/8];
+ uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+typedef struct {
+ ggml_half d;
+ uint8_t qs[QK_K/8];
+ uint16_t qh[QK_K/32];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
+// 1.75 bpw
+typedef struct {
+ uint8_t qs[QK_K/8]; // grid index, low 8 bits
+ uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8)
+ uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+//
+// Bitnet - implemented as 1.75 bpw
+// The block scale is a waste, but it allows us to plug it in without any additional
+// changes to ggml.
+//
+#define QK_IQ1BN 64
+typedef struct {
+ uint8_t ql[12];
+ uint8_t extra;
+} block_iq1_bn;
+static_assert(sizeof(block_iq1_bn) == 13, "wrong iq1_bn block size/padding");
+//
+// Bitnet - implemented as 2.25 bpw
+//
+#define QK_IQ2BN 64
+typedef struct {
+ uint8_t qs[QK_IQ2BN/4];
+} block_iq2_bn;
+static_assert(sizeof(block_iq2_bn) == QK_IQ2BN/4, "wrong iq2_bn block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+ ggml_half f16;
+ uint16_t u16;
+} iq1m_scale_t;
+
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+ ggml_half d;
+ uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
+typedef struct {
+ ggml_half d;
+ uint16_t scales_h;
+ uint8_t scales_l[QK_K/64];
+ uint8_t qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+#endif // GGML_COMMON_DECL
+#endif // GGML_COMMON_DECL
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef GGML_COMMON_IMPL
+
+#if defined(GGML_COMMON_IMPL_C)
+#include <stdint.h>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_METAL)
+#include <metal_stdlib>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_SYCL)
+
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#endif
+
+#if defined(GGML_COMMON_IMPL)
+
+GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)
+ 1, 2, 4, 8, 16, 32, 64, 128
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+GGML_TABLE_END()
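+// (ksigns_iq2xs[i] appears to extend the 7 explicit sign bits in i with a parity bit:
+//  bit 7 is set exactly when the low 7 bits have odd popcount, giving 8 sign flags
+//  per entry.)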
+
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
+ 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+ 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+ 0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+ 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+ 0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+ 0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+ 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+ 0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+ 0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+ 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+ 0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+ 0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+ 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+ 0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+ 0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+ 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+ 0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+ 0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+ 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+ 0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+ 0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+ 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+ 0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+ 0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+ 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+ 0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+ 0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+ 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+ 0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+ 0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+ 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+ 0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+GGML_TABLE_END()
+//#endif
+
+
+GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512)
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024)
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
+ 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
+ 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
+ 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
+ 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
+ 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
+ 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
+ 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
+ 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
+ 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
+ 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
+ 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
+ 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
+ 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
+ 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
+ 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
+ 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
+ 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
+ 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
+ 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
+ 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
+ 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
+ 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
+ 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
+ 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
+ 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
+ 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
+ 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
+ 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
+ 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
+ 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
+ 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
+ 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
+ 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
+ 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
+ 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
+ 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
+ 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
+ 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
+ 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
+ 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
+ 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
+ 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
+ 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
+ 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
+ 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
+ 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
+ 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
+ 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
+ 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
+ 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
+ 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
+ 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
+ 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
+ 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
+ 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
+ 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
+ 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
+ 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
+ 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
+ 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
+ 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
+ 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
+ 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
+ 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
+ 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
+ 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
+ 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
+ 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
+ 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
+ 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
+ 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
+ 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
+ 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
+ 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
+ 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
+ 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
+ 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
+ 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
+ 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
+ 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
+ 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
+ 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
+ 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
+ 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
+ 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
+ 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
+ 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
+ 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
+ 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
+ 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
+ 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
+ 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
+ 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
+ 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
+ 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
+ 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
+ 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
+ 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
+ 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
+ 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
+ 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
+ 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
+ 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
+ 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
+ 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
+ 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
+ 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
+ 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
+ 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
+ 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
+ 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
+ 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
+ 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
+ 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
+ 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
+ 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
+ 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
+ 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
+ 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
+ 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
+ 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
+ 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
+ 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
+ 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
+ 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
+ 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
+ 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
+ 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
+ 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
+ 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
+ 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
+ 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
+ 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
+ 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
+ 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
+ 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
+ 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
+ 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
+ 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
+ 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
+ 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
+ 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
+ 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
+ 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
+ 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
+ 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
+ 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
+ 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
+ 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
+ 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
+ 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
+ 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
+ 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
+ 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
+ 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
+ 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
+ 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
+ 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
+ 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
+ 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
+ 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
+ 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
+ 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
+ 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
+ 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
+ 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
+ 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
+ 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
+ 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
+ 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
+ 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
+ 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
+ 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
+ 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
+ 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
+ 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
+ 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
+ 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
+ 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
+ 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
+ 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
+ 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
+ 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
+ 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
+ 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
+ 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
+ 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
+ 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
+ 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
+ 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
+ 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
+ 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
+ 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
+ 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
+ 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
+ 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
+ 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
+ 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
+ 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
+ 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
+ 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
+ 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
+ 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
+ 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
+ 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
+ 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
+ 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
+ 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
+ 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
+ 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
+ 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
+ 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
+ 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
+ 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
+ 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
+ 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
+ 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
+ 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
+ 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
+ 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
+ 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
+ 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
+ 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
+ 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
+ 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
+ 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
+ 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
+ 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
+ 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
+ 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
+ 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
+ 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
+ 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
+ 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
+ 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
+ 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
+ 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
+ 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
+ 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
+ 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
+ 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
+ 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
+ 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
+ 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
+ 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
+ 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256)
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
+GGML_TABLE_END()
+
+#define NGRID_IQ1S 2048
+#define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
+#if defined(GGML_COMMON_IMPL_C)
+GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
+ 0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
+ 0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff,
+ 0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000,
+ 0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000,
+ 0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101,
+ 0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff,
+ 0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00,
+ 0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101,
+ 0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100,
+ 0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00,
+ 0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000,
+ 0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000,
+ 0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101,
+ 0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff,
+ 0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100,
+ 0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01,
+ 0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000,
+ 0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff,
+ 0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001,
+ 0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000,
+ 0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00,
+ 0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff,
+ 0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff,
+ 0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000,
+ 0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff,
+ 0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff,
+ 0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101,
+ 0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff,
+ 0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000,
+ 0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100,
+ 0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01,
+ 0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00,
+ 0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00,
+ 0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01,
+ 0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff,
+ 0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000,
+ 0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff,
+ 0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000,
+ 0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101,
+ 0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100,
+ 0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff,
+ 0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff,
+ 0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001,
+ 0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01,
+ 0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff,
+ 0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000,
+ 0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000,
+ 0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff,
+ 0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01,
+ 0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00,
+ 0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000,
+ 0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000,
+ 0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000,
+ 0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001,
+ 0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff,
+ 0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01,
+ 0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff,
+ 0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff,
+ 0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000,
+ 0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000,
+ 0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00,
+ 0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001,
+ 0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100,
+ 0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101,
+ 0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00,
+ 0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00,
+ 0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000,
+ 0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001,
+ 0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff,
+ 0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000,
+ 0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000,
+ 0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff,
+ 0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100,
+ 0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000,
+ 0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101,
+ 0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001,
+ 0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000,
+ 0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff,
+ 0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01,
+ 0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100,
+ 0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000,
+ 0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01,
+ 0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101,
+ 0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000,
+ 0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff,
+ 0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff,
+ 0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001,
+ 0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001,
+ 0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000,
+ 0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000,
+ 0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00,
+ 0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100,
+ 0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01,
+ 0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00,
+ 0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff,
+ 0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000,
+ 0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00,
+ 0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000,
+ 0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff,
+ 0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00,
+ 0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000,
+ 0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000,
+ 0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff,
+ 0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00,
+ 0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00,
+ 0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001,
+ 0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001,
+ 0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000,
+ 0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100,
+ 0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101,
+ 0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101,
+ 0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000,
+ 0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00,
+ 0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff,
+ 0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000,
+ 0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff,
+ 0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff,
+ 0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff,
+ 0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101,
+ 0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001,
+ 0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100,
+ 0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101,
+ 0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01,
+ 0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00,
+ 0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00,
+ 0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000,
+ 0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101,
+ 0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100,
+ 0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001,
+ 0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff,
+ 0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00,
+ 0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000,
+ 0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100,
+ 0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00,
+ 0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01,
+ 0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff,
+ 0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101,
+ 0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100,
+ 0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100,
+ 0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000,
+ 0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff,
+ 0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff,
+ 0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000,
+ 0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101,
+ 0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000,
+ 0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,
+ 0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101,
+ 0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100,
+ 0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01,
+ 0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001,
+ 0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001,
+ 0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff,
+ 0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff,
+ 0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000,
+ 0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000,
+ 0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101,
+ 0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff,
+ 0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001,
+ 0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000,
+ 0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff,
+ 0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00,
+ 0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff,
+ 0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000,
+ 0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00,
+ 0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff,
+ 0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000,
+ 0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000,
+ 0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff,
+ 0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000,
+ 0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000,
+ 0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000,
+ 0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101,
+ 0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff,
+ 0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100,
+ 0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101,
+ 0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001,
+ 0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff,
+ 0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100,
+ 0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff,
+ 0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01,
+ 0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff,
+ 0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff,
+ 0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff,
+ 0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff,
+ 0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001,
+ 0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff,
+ 0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01,
+ 0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00,
+ 0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100,
+ 0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101,
+ 0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff,
+ 0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00,
+ 0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01,
+ 0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00,
+ 0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100,
+ 0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00,
+ 0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000,
+ 0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000,
+ 0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000,
+ 0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000,
+ 0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001,
+ 0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff,
+ 0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00,
+ 0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff,
+ 0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00,
+ 0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000,
+ 0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff,
+ 0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff,
+ 0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101,
+ 0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000,
+ 0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff,
+ 0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff,
+ 0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff,
+ 0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff,
+ 0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000,
+ 0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000,
+ 0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100,
+ 0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00,
+ 0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100,
+ 0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100,
+ 0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00,
+ 0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff,
+ 0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff,
+ 0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100,
+ 0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff,
+ 0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000,
+ 0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00,
+ 0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+ 0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00,
+ 0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100,
+ 0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff,
+ 0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff,
+ 0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001,
+ 0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00,
+ 0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101,
+ 0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001,
+ 0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000,
+ 0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100,
+ 0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000,
+ 0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001,
+ 0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000,
+ 0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff,
+ 0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101,
+ 0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+ 0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101,
+ 0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000,
+ 0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff,
+ 0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000,
+ 0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff,
+ 0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01,
+ 0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100,
+ 0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff,
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101,
+ 0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001,
+ 0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01,
+ 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff,
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+ 0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff,
+ 0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00,
+ 0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff,
+ 0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff,
+ 0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff,
+ 0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001,
+ 0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff,
+ 0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff,
+ 0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001,
+ 0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff,
+ 0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff,
+ 0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000,
+ 0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00,
+ 0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100,
+ 0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01,
+ 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff,
+ 0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff,
+ 0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000,
+ 0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101,
+ 0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01,
+ 0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100,
+ 0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100,
+ 0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00,
+ 0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101,
+ 0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001,
+ 0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01,
+ 0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100,
+ 0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff,
+ 0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01,
+ 0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff,
+ 0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000,
+ 0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001,
+ 0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01,
+ 0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff,
+ 0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff,
+ 0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00,
+ 0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff,
+ 0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100,
+ 0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000,
+ 0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001,
+ 0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100,
+ 0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff,
+ 0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff,
+ 0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01,
+ 0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101,
+ 0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001,
+ 0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff,
+ 0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101,
+ 0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100,
+ 0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00,
+ 0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01,
+ 0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001,
+ 0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00,
+ 0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000,
+ 0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101,
+ 0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001,
+ 0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100,
+ 0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000,
+ 0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000,
+ 0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00,
+ 0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101,
+ 0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff,
+ 0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101,
+ 0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001,
+ 0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01,
+ 0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100,
+ 0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000,
+ 0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00,
+ 0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff,
+ 0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001,
+ 0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001,
+ 0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff,
+ 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00,
+ 0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100,
+ 0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00,
+ 0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101,
+ 0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff,
+ 0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff,
+ 0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff,
+ 0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00,
+ 0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff,
+ 0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff,
+ 0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101,
+ 0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100,
+ 0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101,
+ 0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00,
+ 0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00,
+ 0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff,
+ 0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff,
+ 0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001,
+ 0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000,
+ 0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff,
+ 0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff,
+ 0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff,
+ 0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff,
+ 0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001,
+ 0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100,
+ 0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101,
+ 0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001,
+ 0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001,
+ 0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101,
+ 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101,
+ 0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff,
+ 0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff,
+ 0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000,
+ 0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101,
+ 0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001,
+ 0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff,
+ 0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000,
+ 0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff,
+ 0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100,
+ 0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000,
+ 0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101,
+ 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff,
+ 0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100,
+ 0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff,
+ 0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01,
+ 0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100,
+ 0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000,
+ 0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100,
+ 0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100,
+ 0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001,
+ 0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000,
+ 0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000,
+ 0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000,
+ 0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00,
+ 0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff,
+ 0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001,
+ 0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000,
+ 0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101,
+ 0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101,
+ 0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000,
+ 0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff,
+ 0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000,
+ 0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101,
+ 0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff,
+ 0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff,
+ 0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000,
+ 0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101,
+ 0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff,
+ 0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff,
+ 0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01,
+ 0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff,
+ 0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000,
+ 0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101,
+ 0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff,
+ 0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff,
+ 0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff,
+ 0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01,
+ 0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00,
+ 0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000,
+ 0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000,
+ 0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101,
+ 0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001,
+ 0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff,
+ 0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000,
+ 0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff,
+ 0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000,
+ 0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000,
+ 0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000,
+ 0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000,
+ 0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001,
+ 0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff,
+ 0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,
+ 0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000,
+ 0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000,
+ 0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00,
+ 0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000,
+ 0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101,
+ 0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001,
+ 0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00,
+ 0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001,
+ 0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff,
+ 0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000,
+ 0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00,
+ 0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100,
+ 0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101,
+ 0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001,
+ 0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01,
+ 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff,
+ 0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff,
+ 0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00,
+ 0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01,
+ 0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100,
+ 0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000,
+ 0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff,
+ 0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff,
+ 0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff,
+ 0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000,
+ 0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff,
+ 0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101,
+ 0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff,
+ 0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001,
+ 0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001,
+ 0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001,
+ 0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000,
+ 0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000,
+ 0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff,
+ 0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101,
+ 0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001,
+ 0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff,
+ 0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000,
+ 0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000,
+ 0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff,
+ 0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101,
+ 0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000,
+ 0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100,
+ 0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00,
+ 0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000,
+ 0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff,
+ 0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01,
+ 0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00,
+ 0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff,
+ 0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000,
+ 0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101,
+ 0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff,
+ 0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001,
+ 0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff,
+ 0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000,
+ 0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100,
+ 0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff,
+ 0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101,
+ 0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100,
+ 0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff,
+ 0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01,
+ 0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000,
+ 0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff,
+ 0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00,
+ 0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100,
+ 0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff,
+ 0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001,
+ 0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00,
+ 0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100,
+ 0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff,
+ 0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01,
+ 0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00,
+ 0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000,
+ 0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01,
+ 0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff,
+ 0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000,
+ 0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff,
+ 0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff,
+ 0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff,
+ 0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001,
+ 0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff,
+ 0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01,
+ 0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff,
+ 0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00,
+ 0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00,
+ 0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001,
+ 0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101,
+ 0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101,
+ 0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff,
+ 0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000,
+ 0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101,
+GGML_TABLE_END()
+#else
+GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
+ 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+ 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+ 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+ 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+ 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+ 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+ 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+ 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+ 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+ 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+ 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+ 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+ 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+ 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+ 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+ 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+ 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+ 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+ 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+ 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+ 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+ 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+ 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+ 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+ 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+ 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+ 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+ 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+ 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+ 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+ 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+ 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+ 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+ 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+ 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+ 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+ 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+ 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+ 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+ 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+ 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+ 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+ 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+ 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+ 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+ 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+ 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+ 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+ 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+ 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+ 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+ 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+ 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+ 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+ 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+ 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+ 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+ 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+ 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+ 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+ 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+ 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+ 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+ 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+ 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+ 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+ 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+ 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+ 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+ 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+ 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+ 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+ 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+ 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+ 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+ 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+ 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+ 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+ 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+ 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+ 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+ 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+ 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+ 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+ 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+ 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+ 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+ 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+ 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+ 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+ 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+ 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+ 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+ 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+ 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+ 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+ 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+ 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+ 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+ 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+ 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+ 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+ 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+ 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+ 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+ 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+ 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+ 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+ 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+ 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+ 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+ 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+ 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+ 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+ 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+ 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+ 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+ 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+ 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+ 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+ 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+ 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+ 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+ 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+ 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+ 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+ 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+ 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+ 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+ 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+ 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+ 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+ 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+ 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+ 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+ 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+ 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+ 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+ 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+ 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+ 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+ 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+ 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+ 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+ 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+ 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+ 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+ 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+ 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+ 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+ 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+ 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+ 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+ 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+ 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+ 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+ 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+ 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+ 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+ 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+ 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+ 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+ 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+ 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+ 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+ 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+ 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+ 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+ 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+ 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+ 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+ 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+ 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+ 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+ 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+ 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+ 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+ 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+ 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+ 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+ 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+ 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+ 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+ 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+ 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+ 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+ 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+ 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+ 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+ 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+ 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+ 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+ 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+ 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+ 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+ 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+ 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+ 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+ 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+ 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+ 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+ 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+ 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+ 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+ 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+ 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+ 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+ 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+ 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+ 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+ 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+ 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+ 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+ 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+ 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+ 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+ 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+ 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+ 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+ 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+ 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+ 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+ 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+ 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+ 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+ 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+ 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+ 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+ 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+ 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+ 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+ 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+ 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+ 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+ 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+ 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+ 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+ 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+ 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+ 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+ 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+ 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+ 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+ 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+ 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+ 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+ 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+ 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+ 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+ 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+ 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+ 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+ 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+ 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+ 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+ 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+GGML_TABLE_END()
+#endif
+
+#endif // GGML_COMMON_IMPL
+#endif // GGML_COMMON_IMPL
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
new file mode 100644
index 00000000..59cf434c
--- /dev/null
+++ b/ggml/src/ggml-cuda.cu
@@ -0,0 +1,3079 @@
+#include "ggml-cuda.h"
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-cuda/common.cuh"
+#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argsort.cuh"
+#include "ggml-cuda/binbcast.cuh"
+#include "ggml-cuda/clamp.cuh"
+#include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
+#include "ggml-cuda/getrows.cuh"
+#include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/pad.cuh"
+#include "ggml-cuda/pool2d.cuh"
+#include "ggml-cuda/quantize.cuh"
+#include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/unary.cuh"
+#include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <float.h>
+#include <limits>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ GGML_UNUSED(level);
+ GGML_UNUSED(user_data);
+ fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback;
+void * ggml_cuda_log_user_data = NULL;
+
+GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ ggml_cuda_log_callback = log_callback;
+ ggml_cuda_log_user_data = user_data;
+}
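+
+// hypothetical usage sketch (names below are illustrative, not part of the API):
+// route CUDA log messages into an application-provided sink.
+//
+//     static void my_cuda_logger(enum ggml_log_level level, const char * msg, void * user_data) {
+//         (void) level;
+//         fprintf((FILE *) user_data, "%s", msg);
+//     }
+//     ...
+//     ggml_backend_cuda_log_set_callback(my_cuda_logger, stderr);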
+
+#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
+ if (ggml_cuda_log_callback != NULL) {
+ va_list args;
+ va_start(args, format);
+ char buffer[128];
+ int len = vsnprintf(buffer, 128, format, args);
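+        // vsnprintf returns the length the formatted message would have had, excluding the
+        // terminating null; if it fits in the 128-byte stack buffer use it directly,
+        // otherwise retry into a heap buffer of the exact required size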
+ if (len < 128) {
+ ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data);
+ } else {
+ std::vector<char> buffer2(len + 1); // vsnprintf adds a null terminator
+ va_end(args);
+ va_start(args, format);
+ vsnprintf(&buffer2[0], buffer2.size(), format, args);
+ ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data);
+ }
+ va_end(args);
+ }
+}
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
+ int id = -1; // in case cudaGetDevice fails
+ cudaGetDevice(&id);
+
+ GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg);
+ GGML_CUDA_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
+ GGML_CUDA_LOG_ERROR(" %s\n", stmt);
+ // abort with GGML_ASSERT to get a stack trace
+ GGML_ASSERT(!"CUDA error");
+}
+
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+void ggml_cuda_set_device(int device) {
+ int current_device;
+ CUDA_CHECK(cudaGetDevice(&current_device));
+
+ if (device == current_device) {
+ return;
+ }
+
+ CUDA_CHECK(cudaSetDevice(device));
+}
+
+int ggml_cuda_get_device() {
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ return id;
+}
+
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+ ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
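+    // with GGML_HIP_UMA, allocate managed (unified) memory instead of a plain device
+    // buffer and advise the runtime to treat it as coarse-grained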
+ auto res = hipMallocManaged(ptr, size);
+ if (res == hipSuccess) {
+        // wrap the advise call in CUDA_CHECK so that, if it fails, we know why
+ CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+ }
+ return res;
+#else
+ return cudaMalloc(ptr, size);
+#endif
+}
+
+static ggml_cuda_device_info ggml_cuda_init() {
+#ifdef __HIP_PLATFORM_AMD__
+ // Workaround for a rocBLAS bug when using multiple graphics cards:
+ // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+ rocblas_initialize();
+ CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
+ ggml_cuda_device_info info = {};
+
+ cudaError_t err = cudaGetDeviceCount(&info.device_count);
+ if (err != cudaSuccess) {
+ GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+ return info;
+ }
+
+ GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
+
+ int64_t total_vram = 0;
+#ifdef GGML_CUDA_FORCE_MMQ
+ GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
+#else
+ GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+ GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
+#else
+ GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
+ GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+ for (int id = 0; id < info.device_count; ++id) {
+ int device_vmm = 0;
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+ CUdevice device;
+ CU_CHECK(cuDeviceGet(&device, id));
+ CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+ if (device_vmm) {
+ CUmemAllocationProp alloc_prop = {};
+ alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ alloc_prop.location.id = id;
+ CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
+ }
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+ info.devices[id].vmm = !!device_vmm;
+
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+ GGML_CUDA_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+
+ info.default_tensor_split[id] = total_vram;
+ total_vram += prop.totalGlobalMem;
+
+ info.devices[id].nsm = prop.multiProcessorCount;
+ info.devices[id].smpb = prop.sharedMemPerBlock;
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ info.devices[id].smpbo = prop.sharedMemPerBlock;
+ info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+#else
+ info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+ info.devices[id].cc = 100*prop.major + 10*prop.minor;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ }
+
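+    // convert the cumulative per-device VRAM recorded above into fractions of the total;
+    // these fractions are the default tensor split across devices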
+ for (int id = 0; id < info.device_count; ++id) {
+ info.default_tensor_split[id] /= total_vram;
+ }
+
+ // configure logging to stdout
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+ return info;
+}
+
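+// the device info is initialized once on first use and cached for the lifetime of the process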
+const ggml_cuda_device_info & ggml_cuda_info() {
+ static ggml_cuda_device_info info = ggml_cuda_init();
+ return info;
+}
+
+// #define DEBUG_CUDA_MALLOC
+
+// buffer pool for cuda (legacy)
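+// keeps up to MAX_BUFFERS freed allocations per device and reuses them with a best-fit search;
+// new allocations are padded by ~5% (rounded up to a multiple of 256 bytes) to make future reuse more likely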
+struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+ static const int MAX_BUFFERS = 256;
+
+ int device;
+ struct ggml_cuda_buffer {
+ void * ptr = nullptr;
+ size_t size = 0;
+ };
+
+ ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
+ size_t pool_size = 0;
+
+ explicit ggml_cuda_pool_leg(int device) :
+ device(device) {
+ }
+
+ ~ggml_cuda_pool_leg() {
+ ggml_cuda_set_device(device);
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer & b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+ CUDA_CHECK(cudaFree(b.ptr));
+ pool_size -= b.size;
+ }
+ }
+ GGML_ASSERT(pool_size == 0);
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
+#ifdef DEBUG_CUDA_MALLOC
+ int nnz = 0;
+ size_t max_size = 0;
+#endif
+ size_t best_diff = 1ull << 36;
+ int ibest = -1;
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer& b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+ ++nnz;
+ if (b.size > max_size) max_size = b.size;
+#endif
+ if (b.size >= size) {
+ size_t diff = b.size - size;
+ if (diff < best_diff) {
+ best_diff = diff;
+ ibest = i;
+ if (!best_diff) {
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ }
+ }
+ }
+ }
+ if (ibest >= 0) {
+ ggml_cuda_buffer& b = buffer_pool[ibest];
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ void * ptr;
+ size_t look_ahead_size = (size_t) (1.05 * size);
+ look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
+ *actual_size = look_ahead_size;
+ pool_size += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+ GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+ (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
+#endif
+ return ptr;
+ }
+
+ void free(void * ptr, size_t size) override {
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer& b = buffer_pool[i];
+ if (b.ptr == nullptr) {
+ b.ptr = ptr;
+ b.size = size;
+ return;
+ }
+ }
+ GGML_CUDA_LOG_WARN("CUDA buffer pool full, increase MAX_BUFFERS\n");
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaFree(ptr));
+ pool_size -= size;
+ }
+};
+
+// pool with virtual memory
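+// reserves a large virtual address range once and maps physical memory into it on demand;
+// allocations are bump-allocated from the mapped range and must be freed in reverse order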
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
+ static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
+
+ int device;
+ CUdeviceptr pool_addr = 0;
+ size_t pool_used = 0;
+ size_t pool_size = 0;
+ size_t granularity;
+
+ explicit ggml_cuda_pool_vmm(int device) :
+ device(device),
+ granularity(ggml_cuda_info().devices[device].vmm_granularity) {
+ }
+
+ ~ggml_cuda_pool_vmm() {
+ if (pool_addr != 0) {
+ CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+ CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
+ }
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
+ // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+ const size_t alignment = 128;
+ size = alignment * ((size + alignment - 1) / alignment);
+
+ size_t avail = pool_size - pool_used;
+
+ if (size > avail) {
+ // round up to the next multiple of the granularity
+ size_t reserve_size = size - avail;
+ reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+ GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+ // allocate more physical memory
+ CUmemAllocationProp prop = {};
+ prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ prop.location.id = device;
+ CUmemGenericAllocationHandle handle;
+ CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+ // reserve virtual address space (if not already reserved)
+ if (pool_addr == 0) {
+ CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+ }
+
+ // map at the end of the pool
+ CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
+
+ // the memory allocation handle is no longer needed after mapping
+ CU_CHECK(cuMemRelease(handle));
+
+ // set access
+ CUmemAccessDesc access = {};
+ access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ access.location.id = device;
+ access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+ CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
+
+ // add to the pool
+ pool_size += reserve_size;
+
+ //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+ // device, (unsigned long long) (pool_size/1024/1024),
+ // (unsigned long long) (reserve_size/1024/1024));
+ }
+
+ GGML_ASSERT(pool_addr != 0);
+
+ void * ptr = (void *) (pool_addr + pool_used);
+ *actual_size = size;
+ pool_used += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+ return ptr;
+ }
+
+ void free(void * ptr, size_t size) override {
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+ pool_used -= size;
+
+ // all deallocations must be in reverse order of the allocations
+ GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
+ }
+};
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+
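+// prefer the virtual memory pool when the device supports it, otherwise fall back to the legacy pool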
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
+ if (ggml_cuda_info().devices[device].vmm) {
+ return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
+ }
+#endif
+ return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
+}
+
+// cuda buffer
+
+struct ggml_backend_cuda_buffer_context {
+ int device;
+ void * dev_ptr = nullptr;
+ std::string name;
+
+ ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
+ device(device), dev_ptr(dev_ptr),
+ name(GGML_CUDA_NAME + std::to_string(device)) {
+ }
+
+ ~ggml_backend_cuda_buffer_context() {
+ CUDA_CHECK(cudaFree(dev_ptr));
+ }
+};
+
+GGML_CALL static const char * ggml_backend_cuda_buffer_get_name(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+ return ctx->name.c_str();
+}
+
+GGML_CALL static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_cuda_buffer_get_name;
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+ delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+ return ctx->dev_ptr;
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ if (tensor->view_src != NULL) {
+ assert(tensor->view_src->buffer->buft == buffer->buft);
+ return;
+ }
+
+ if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+ // initialize padding to 0 to avoid possible NaN values
+ size_t original_size = ggml_nbytes(tensor);
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+ if (padded_size > original_size) {
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
+ }
+ }
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_cuda(src->buffer)) {
+ ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
+ if (src_ctx->device == dst_ctx->device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
+ } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return false;
+#else
+ CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+#endif
+ }
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ return true;
+ }
+ return false;
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaDeviceSynchronize());
+ CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
+ CUDA_CHECK(cudaDeviceSynchronize());
+}
+
+static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
+ /* .get_name = */ ggml_backend_cuda_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_cuda_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// cuda buffer type
+struct ggml_backend_cuda_buffer_type_context {
+ int device;
+ std::string name;
+};
+
+GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+ return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_buffer_type_name;
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+ ggml_cuda_set_device(buft_ctx->device);
+
+ size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
+
+ void * dev_ptr;
+ cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
+ if (err != cudaSuccess) {
+ // clear the error
+ cudaGetLastError();
+ GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
+ return nullptr;
+ }
+
+ ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ size_t size = ggml_nbytes(tensor);
+ int64_t ne0 = tensor->ne[0];
+
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return size;
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_cuda_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL,
+};
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ if (device >= ggml_backend_cuda_get_device_count()) {
+ return nullptr;
+ }
+
+ static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
+
+ static bool ggml_backend_cuda_buffer_type_initialized = false;
+
+ if (!ggml_backend_cuda_buffer_type_initialized) {
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) {
+ ggml_backend_cuda_buffer_types[i] = {
+ /* .iface = */ ggml_backend_cuda_buffer_type_interface,
+ /* .context = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
+ };
+ }
+ ggml_backend_cuda_buffer_type_initialized = true;
+ }
+
+ return &ggml_backend_cuda_buffer_types[device];
+}
+
+// cuda split buffer
+
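+// helpers for distributing the rows of a split tensor across devices: row boundaries follow the
+// cumulative tensor_split fractions and are rounded down to a multiple of the largest mul_mat_q
+// tile height among the participating devices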
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+ int64_t row_rounding = 0;
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+ continue;
+ }
+
+ const int cc = ggml_cuda_info().devices[id].cc;
+ row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
+ }
+ return row_rounding;
+}
+
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
+ const int64_t nrows = ggml_nrows(tensor);
+ const int64_t rounding = get_row_rounding(tensor_split);
+
+ *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+ *row_low -= *row_low % rounding;
+
+ if (id == ggml_backend_cuda_get_device_count() - 1) {
+ *row_high = nrows;
+ } else {
+ *row_high = nrows*tensor_split[id + 1];
+ *row_high -= *row_high % rounding;
+ }
+}
+
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
+struct ggml_backend_cuda_split_buffer_type_context {
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+};
+
+struct ggml_backend_cuda_split_buffer_context {
+ ~ggml_backend_cuda_split_buffer_context() {
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+ for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
+ for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+ if (extra->events[id][is] != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+ }
+ }
+ if (extra->data_device[id] != nullptr) {
+ CUDA_CHECK(cudaFree(extra->data_device[id]));
+ }
+ }
+ delete extra;
+ }
+ }
+
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;
+};
+
+GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return GGML_CUDA_NAME "_Split";
+
+ GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name;
+ GGML_UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+ delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+ // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
+ return (void *)0x1000;
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+
+ ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+ ctx->tensor_extras.push_back(extra);
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ // FIXME: do not crash if cudaMalloc fails
+ // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+ ggml_cuda_set_device(id);
+ char * buf;
+ CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
+
+ // set padding to 0 to avoid possible NaN values
+ if (size > original_size) {
+ CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+ }
+
+ extra->data_device[id] = buf;
+
+ for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+ }
+ }
+ tensor->extra = extra;
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ // split tensors must always be set in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ const char * buf_host = (const char *)data + offset_split;
+ CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ }
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ // split tensors must always be read in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ char * buf_host = (char *)data + offset_split;
+ CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ }
+}
+
+GGML_CALL static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(value);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
+ /* .get_name = */ ggml_backend_cuda_split_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_cuda_split_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// cuda split buffer type
+
+GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_CUDA_NAME "_Split";
+
+ GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name;
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+ // instead, we allocate them for each tensor separately in init_tensor
+ // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+ // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+ ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
+
+ return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+
+ size_t total_size = 0;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ total_size += ggml_nbytes_split(tensor, nrows_split);
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return total_size;
+}
+
+GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_cuda_split_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
+ /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
+};
+
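+// split buffer types are cached per tensor split; the user-supplied per-device fractions are
+// converted to normalized cumulative offsets, or the VRAM-proportional default split is used
+// when no fractions are given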
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
+
+ bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
+ if (all_zero) {
+ tensor_split_arr = ggml_cuda_info().default_tensor_split;
+ } else {
+ float split_sum = 0.0f;
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+ tensor_split_arr[i] = split_sum;
+ split_sum += tensor_split[i];
+ }
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+ tensor_split_arr[i] /= split_sum;
+ }
+ }
+
+ auto it = buft_map.find(tensor_split_arr);
+ if (it != buft_map.end()) {
+ return &it->second;
+ }
+
+ struct ggml_backend_buffer_type buft {
+ /* .iface = */ ggml_backend_cuda_split_buffer_type_interface,
+ /* .context = */ new ggml_backend_cuda_split_buffer_type_context{tensor_split_arr},
+ };
+
+ auto result = buft_map.emplace(tensor_split_arr, buft);
+ return &result.first->second;
+}
+
+// host buffer type
+
+GGML_CALL static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_CUDA_NAME "_Host";
+
+ GGML_UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_cuda_host_buffer_name(ggml_backend_buffer_t buffer) {
+ return GGML_CUDA_NAME "_Host";
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ CUDA_CHECK(cudaFreeHost(buffer->context));
+}
+
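+// allocate pinned (page-locked) host memory; returns nullptr if pinning is disabled via the
+// GGML_CUDA_NO_PINNED environment variable or if cudaMallocHost fails, in which case the
+// caller falls back to a regular CPU buffer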
+static void * ggml_cuda_host_malloc(size_t size) {
+ if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+ return nullptr;
+ }
+
+ void * ptr = nullptr;
+ cudaError_t err = cudaMallocHost((void **) &ptr, size);
+ if (err != cudaSuccess) {
+ // clear the error
+ cudaGetLastError();
+ GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+ size / 1024.0 / 1024.0, cudaGetErrorString(err));
+ return nullptr;
+ }
+
+ return ptr;
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ void * ptr = ggml_cuda_host_malloc(size);
+
+ if (ptr == nullptr) {
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.get_name = ggml_backend_cuda_host_buffer_name;
+ buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+ return buffer;
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+ static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_cuda_host_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+ },
+ /* .context = */ nullptr,
+ };
+
+ return &ggml_backend_cuda_buffer_type_host;
+}
+
+//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
+// return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+//}
+
+/// kernels
+
+typedef void (*ggml_cuda_op_mul_mat_t)(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
+
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
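+// F16 x F32 mat-vec kernel for src0 stored with a 0213 permutation: each block (a single warp)
+// accumulates the dot product for one (row, channel) pair and reduces it with warp_reduce_sum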
+static __global__ void mul_mat_p021_f16_f32(
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
+ const half * x = (const half *) vx;
+
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / (nchannels_y / nchannels_x);
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+ const int col_x = col_x0 + threadIdx.x;
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ // x is transposed and permuted
+ const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+ const float xi = __half2float(x[ix]);
+
+ const int row_y = col_x;
+
+ // y is not transposed but permuted
+ const int iy = channel*nrows_y + row_y;
+
+ tmp += xi * y[iy];
+ }
+
+ // dst is not transposed and not permuted
+ const int idst = channel*nrows_dst + row_dst;
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[idst] = tmp;
+ }
+}
+
+static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+ const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
+
+ const half * x = (const half *) vx;
+
+ const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
+ const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / channel_x_divisor;
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ const int idst = channel*nrows_dst + row_dst;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
+ const int col_x = col_x0 + threadIdx.x;
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ const int row_y = col_x;
+
+ const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+ const int iy = channel*nrows_y + row_y;
+
+ const float xi = __half2float(x[ix]);
+
+ tmp += xi * y[iy];
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[idst] = tmp;
+ }
+}
+
+static void ggml_mul_mat_p021_f16_f32_cuda(
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+ const dim3 block_nums(1, nrows_x, nchannels_y);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_cuda(
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
+ const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
+
+ const dim3 block_nums(1, nrows_x, nchannels_y);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
+}
+
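+// copy rows [i1_low, i1_high) of the (i2, i3) slice of a tensor into a contiguous destination buffer,
+// using a single memcpy for fully contiguous rows, a 2D strided copy when only the row stride differs,
+// and a per-row fallback otherwise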
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+ void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
+ char * src_ptr = (char *) src->data;
+ char * dst_ptr = (char *) dst;
+
+ const int64_t ne0 = src->ne[0];
+ const int64_t nb0 = src->nb[0];
+ const int64_t nb1 = src->nb[1];
+ const int64_t nb2 = src->nb[2];
+ const int64_t nb3 = src->nb[3];
+ const enum ggml_type type = src->type;
+ const int64_t ts = ggml_type_size(type);
+ const int64_t bs = ggml_blck_size(type);
+ int64_t i1_diff = i1_high - i1_low;
+
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+ if (nb0 == ts && nb1 == ts*ne0/bs) {
+ return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
+ } else if (nb0 == ts) {
+ return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
+ } else {
+ for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+ const void * rx = (const void *) ((const char *) x + i1*nb1);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ // pretend the row is a matrix with cols=1
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
+ if (r != cudaSuccess) {
+ return r;
+ }
+ }
+ return cudaSuccess;
+ }
+}
+
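+// cuBLAS path for the mul_mat op: on Volta or newer, when src0 is F16 or quantized and default
+// precision is requested, the inputs are converted to FP16 and multiplied with cublasGemmEx using
+// FP16 compute; otherwise both inputs are converted to FP32 and multiplied with cublasSgemm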
+static void ggml_cuda_op_mul_mat_cublas(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ GGML_ASSERT(src0_dd_i != nullptr);
+ GGML_ASSERT(src1_ddf_i != nullptr);
+ GGML_ASSERT(dst_dd_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+
+ int id = ggml_cuda_get_device();
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // ldc == nrows of the matrix that cuBLAS writes into
+ int64_t ldc = id == ctx.device ? ne0 : row_diff;
+
+ const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+ if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+ // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+ ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
+ if (src0->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = row_diff*ne00;
+ src0_as_f16.alloc(ne);
+ to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
+ }
+ const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
+
+ ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_f16.alloc(ne);
+ to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
+ }
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+ ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+ CUBLAS_CHECK(
+ cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ } else {
+ ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
+ ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
+
+ if (src0->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src0_ddq_as_f32.alloc(row_diff*ne00);
+ to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+ }
+ if (src1->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src1_ddq_as_f32.alloc(src1_ncols*ne10);
+ to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+ }
+
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+ const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+ CUBLAS_CHECK(
+ cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha, src0_ddf_i, ne00,
+ src1_ddf1_i, ne10,
+ &beta, dst_dd_i, ldc));
+ }
+
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddq_i);
+ GGML_UNUSED(src1_padded_row_size);
+}
+
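+// enable peer-to-peer access between the main device and the other devices for small batches
+// (<= GGML_CUDA_PEER_MAX_BATCH_SIZE tokens) and disable it again for larger ones; the toggling
+// is only done in release builds (NDEBUG)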
+static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
+ static bool peer_access_enabled = false;
+
+ const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+ if (peer_access_enabled == enable_peer_access) {
+ return;
+ }
+
+#ifdef NDEBUG
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ ggml_cuda_set_device(id);
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ ggml_cuda_set_device(id);
+
+ for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
+ if (id == id_other) {
+ continue;
+ }
+ if (id != main_device && id_other != main_device) {
+ continue;
+ }
+
+ int can_access_peer;
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+ if (can_access_peer) {
+ if (enable_peer_access) {
+ cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
+ if (err != cudaErrorPeerAccessAlreadyEnabled) {
+ CUDA_CHECK(err);
+ }
+ } else {
+ cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
+ if (err != cudaErrorPeerAccessNotEnabled) {
+ CUDA_CHECK(err);
+ }
+ }
+ }
+ }
+ }
+
+ ggml_cuda_set_device(main_device);
+#endif // NDEBUG
+
+ peer_access_enabled = enable_peer_access;
+
+ GGML_UNUSED(main_device);
+}
+
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+ void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIPBLAS)
+ // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+ cudaMemcpy3DPeerParms p = {};
+ p.dstDevice = dstDevice;
+ p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+ p.srcDevice = srcDevice;
+ p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+ p.extent = make_cudaExtent(width, height, 1);
+ return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+ // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+ GGML_UNUSED(dstDevice);
+ GGML_UNUSED(srcDevice);
+ return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIPBLAS)
+}
+
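+// generic multi-device driver for mul_mat ops: distributes the rows of src0 across the devices
+// (for split buffers), optionally quantizes src1 to q8_1, processes src1 in column chunks of
+// MUL_MAT_SRC1_COL_STRIDE on per-device streams, invokes `op` for each chunk and copies the
+// partial results back to the main device, synchronizing with CUDA events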
+static void ggml_cuda_op_mul_mat(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+ quantize_cuda_t quantize_src1) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+ const int64_t nrows1 = ggml_nrows(src1);
+
+ GGML_ASSERT(ne03 == ne13);
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+
+ const int64_t nb2 = dst->nb[2];
+ const int64_t nb3 = dst->nb[3];
+
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
+ ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+ GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+
+ const int64_t i02_divisor = ne12 / ne02;
+
+ const size_t src0_ts = ggml_type_size(src0->type);
+ const size_t src0_bs = ggml_blck_size(src0->type);
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
+
+ const bool src0_is_contiguous = ggml_is_contiguous(src0);
+ const bool src1_is_contiguous = ggml_is_contiguous(src1);
+
+ const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+ const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+ GGML_ASSERT(!(split && ne02 > 1));
+ GGML_ASSERT(!(split && ne03 > 1));
+ GGML_ASSERT(!(split && ne02 < ne12));
+
+ ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
+
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+ if (split) {
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+ tensor_split = buft_ctx->tensor_split;
+ }
+
+ struct dev_data {
+ int cc;
+
+ ggml_cuda_pool_alloc<char> src0_dd_alloc;
+ ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+ ggml_cuda_pool_alloc<char> src1_ddq_alloc;
+ ggml_cuda_pool_alloc<float> dst_dd_alloc;
+
+ char * src0_dd = nullptr;
+ float * src1_ddf = nullptr; // float
+ char * src1_ddq = nullptr; // q8_1
+ float * dst_dd = nullptr;
+
+ int64_t row_low;
+ int64_t row_high;
+ };
+
+ dev_data dev[GGML_CUDA_MAX_DEVICES];
+
+ int used_devices = 0;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ dev[id].cc = ggml_cuda_info().devices[id].cc;
+
+ // by default, use all rows
+ dev[id].row_low = 0;
+ dev[id].row_high = ne01;
+
+ // for multi GPU, get the row boundaries from tensor split
+ // and round to mul_mat_q tile sizes
+ if (split) {
+ const int64_t rounding = get_row_rounding(tensor_split);
+
+ if (id != 0) {
+ dev[id].row_low = ne01*tensor_split[id];
+ if (dev[id].row_low < ne01) {
+ dev[id].row_low -= dev[id].row_low % rounding;
+ }
+ }
+
+ if (id != ggml_backend_cuda_get_device_count() - 1) {
+ dev[id].row_high = ne01*tensor_split[id + 1];
+ if (dev[id].row_high < ne01) {
+ dev[id].row_high -= dev[id].row_high % rounding;
+ }
+ }
+ }
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+
+ used_devices++;
+
+ const bool src1_on_device = id == src1_ctx->device;
+ const bool dst_on_device = id == dst_ctx->device;
+
+ ggml_cuda_set_device(id);
+ cudaStream_t stream = ctx.stream(id, 0);
+
+ if (src0_is_contiguous) {
+ dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
+ } else {
+ dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
+ }
+
+ // If src0 is in a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
+ if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
+ const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+ const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+ CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
+ }
+
+ if (src1_on_device && src1_is_contiguous) {
+ dev[id].src1_ddf = (float *) src1->data;
+ } else {
+ dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
+ }
+
+ if (quantize_src1) {
+ size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+ }
+ dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
+
+ if (src1_on_device && src1_is_contiguous) {
+ quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+ }
+
+ if (dst_on_device) {
+ dev[id].dst_dd = (float *) dst->data;
+ } else {
+ const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
+ dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
+ }
+ }
+
+ // if multiple devices are used they need to wait for the main device
+ // here an event is recorded that signals that the main device has finished calculating the input data
+ if (split && used_devices > 1) {
+ ggml_cuda_set_device(ctx.device);
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
+ }
+
+ const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+ for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+ const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
+ const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+
+ const bool src1_on_device = id == src1_ctx->device;
+ const bool dst_on_device = id == dst_ctx->device;
+ const int64_t row_diff = dev[id].row_high - dev[id].row_low;
+
+ ggml_cuda_set_device(id);
+ cudaStream_t stream = ctx.stream(id, is);
+
+ // wait for main GPU data if necessary
+ if (split && (id != ctx.device || is != 0)) {
+ CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
+ }
+
+ for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+ const int64_t i03 = i0 / ne12;
+ const int64_t i02 = i0 % ne12;
+
+ size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+ } else {
+ src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+ }
+
+ // for split tensors each device buffer only contains the rows assigned to that device, so no extra row offset is applied here
+ char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+ float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+ char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
+ float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+ // the main device memory buffer can be on VRAM scratch, with space for all partial results
+ // in that case an offset on dst_dd_i is needed
+ if (id == ctx.device) {
+ dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
+ }
+
+ // copy src0, src1 to device if necessary
+ if (src1_is_contiguous) {
+ if (id != ctx.device) {
+ if (quantize_src1) {
+ char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+ const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+ const size_t height = src1_padded_col_size/(4*QK8_1);
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+ } else {
+ CUDA_CHECK(cudaMemcpyPeerAsync(
+ src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+ }
+ } else {
+ float * src1_ddf_i_source = (float *) src1->data;
+ src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+ CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
+ src1_ncols*ne10*sizeof(float), stream));
+ }
+ }
+ } else if (src1_on_device && !src1_is_contiguous) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+ src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ if (quantize_src1 && !src1_is_contiguous) {
+ quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+ }
+
+ // do the computation
+ op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+ dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ // copy dst to host or other device if necessary
+ if (!dst_on_device) {
+ void * dst_off_device = dst->data;
+ if (split) {
+ // src0 (the weight matrix) is stored transposed for better memory layout, but dst is NOT transposed,
+ // so the partial results of the matrix-matrix multiplications cannot simply be concatenated for >1 GPU.
+ // Instead, each partial result has to be copied into the correct slice along ne0 (the dst row index).
+ // If dst is a vector with ne0 == 1 this copy is not strictly required, but doing it still produces correct results.
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+ dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
+ } else {
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0;
+ CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
+ }
+ }
+
+ // record an event for the main device to wait on until this device is done
+ if (split && (id != ctx.device || is != 0)) {
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
+ }
+ }
+ }
+ }
+
+ // main device waits for all other devices to be finished
+ if (split && ggml_backend_cuda_get_device_count() > 1) {
+ int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+ is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
+
+ ggml_cuda_set_device(ctx.device);
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if (dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+ for (int64_t is = 0; is < is_max; ++is) {
+ CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
+ }
+ }
+ }
+}
+
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t ne12 = src1->ne[2];
+
+ cudaStream_t main_stream = ctx.stream();
+
+ void * src0_ddq = src0->data;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
+}
+
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(!ggml_is_permuted(src0));
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+
+ const int64_t ne12 = src1->ne[2];
+
+ cudaStream_t main_stream = ctx.stream();
+
+ void * src0_ddq = src0->data;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ const int64_t row_stride_x = nb01 / sizeof(half);
+ const int64_t channel_stride_x = nb02 / sizeof(half);
+
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+}
+
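+// build the per-matrix pointer arrays consumed by cublasGemmBatchedEx, applying the broadcast
+// factors r2/r3 to map each (i12, i13) batch index of src1/dst to the corresponding src0 matrix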
+static __global__ void k_compute_batched_ptrs(
+ const half * src0_as_f16, const half * src1_as_f16, char * dst,
+ const void ** ptrs_src, void ** ptrs_dst,
+ int64_t ne12, int64_t ne13,
+ int64_t ne23,
+ size_t nb02, size_t nb03,
+ size_t nb12, size_t nb13,
+ size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3) {
+ int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+ int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+ if (i13 >= ne13 || i12 >= ne12) {
+ return;
+ }
+
+ int64_t i03 = i13 / r3;
+ int64_t i02 = i12 / r2;
+
+ ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+ ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t ne_dst = ggml_nelements(dst);
+
+ cudaStream_t main_stream = ctx.stream();
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
+
+ void * src0_ddq = src0->data;
+ half * src0_f16 = (half *) src0_ddq;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ // convert src1 to fp16
+ ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ const int64_t ne_src1 = ggml_nelements(src1);
+ src1_f16_alloc.alloc(ne_src1);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+ }
+ half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
+
+ ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+ char * dst_t;
+
+ cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+ cudaDataType_t cu_data_type = CUDA_R_16F;
+
+ // dst strides
+ size_t nbd2 = dst->nb[2];
+ size_t nbd3 = dst->nb[3];
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ const void * alpha = &alpha_f16;
+ const void * beta = &beta_f16;
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ dst_t = (char *) dst_f16.alloc(ne_dst);
+
+ nbd2 /= sizeof(float) / sizeof(half);
+ nbd3 /= sizeof(float) / sizeof(half);
+ } else {
+ dst_t = (char *) dst_ddf;
+
+ cu_compute_type = CUBLAS_COMPUTE_32F;
+ cu_data_type = CUDA_R_32F;
+
+ alpha = &alpha_f32;
+ beta = &beta_f32;
+ }
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+#if 0
+ // use cublasGemmEx
+ {
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ int i03 = i13 / r3;
+ int i02 = i12 / r2;
+
+ CUBLAS_CHECK(
+ cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+ (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+ beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
+ cu_compute_type,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ }
+ }
+ }
+#else
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ // use cublasGemmStridedBatchedEx
+ CUBLAS_CHECK(
+ cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
+ (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
+ beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
+ ne12*ne13,
+ cu_compute_type,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ } else {
+ // use cublasGemmBatchedEx
+ const int ne23 = ne12*ne13;
+
+ ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+ ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+ dim3 block_dims(ne13, ne12);
+ k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+ src0_f16, src1_f16, dst_t,
+ ptrs_src.get(), ptrs_dst.get(),
+ ne12, ne13,
+ ne23,
+ nb02, nb03,
+ src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
+ src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
+ nbd2, nbd3,
+ r2, r3);
+ CUDA_CHECK(cudaGetLastError());
+
+ CUBLAS_CHECK(
+ cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
+ (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+ ne23,
+ cu_compute_type,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ }
+#endif
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+ }
+}
+
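+// Top-level matrix multiplication dispatch. Based on the operand types, shapes and the
+// capabilities of the devices involved, this selects between the dequantize_mul_mat_vec (dmmv)
+// and mul_mat_vec_q (mmvq) vector kernels, the quantized mul_mat_q (mmq) kernels, the
+// specialized F16 vector kernels for permuted/non-contiguous KQ and KQV single-token cases,
+// and cuBLAS (batched or per-slice) as the general fallback.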
+static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
+
+ bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
+ && src1->ne[1] == 1;
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+ bool use_mul_mat_q = ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ // if mmvq is available it's a better choice than dmmv:
+#ifndef GGML_CUDA_FORCE_DMMV
+ use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_CUDA_FORCE_DMMV
+
+ bool any_gpus_with_slow_fp16 = false;
+
+ if (split) {
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+ auto & tensor_split = buft_ctx->tensor_split;
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ // skip devices that are not going to do any work:
+ if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+ continue;
+ }
+
+ const int cc = ggml_cuda_info().devices[id].cc;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+ }
+ } else {
+ const int cc = ggml_cuda_info().devices[ctx.device].cc;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
+ }
+
+ // debug helpers
+ //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+ //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+ //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+ //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
+ if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ // FP32 precision KQ single-batch for batch size 1 without FlashAttention
+ ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
+ } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ // FP32 precision KQV single-batch for batch size 1 without FlashAttention
+ ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
+ } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ // KQ + KQV multi-batch without FlashAttention
+ ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
+ } else if (use_dequantize_mul_mat_vec) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
+ } else if (use_mul_mat_vec_q) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
+ } else if (use_mul_mat_q) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+ } else {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
+ }
+}
+
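+// Support for GGML_OP_MUL_MAT_ID (expert routing): for each expert, the rows of src1 that were
+// routed to it are gathered into a contiguous buffer, multiplied with that expert's matrix, and
+// the results are scattered back to their original positions in dst. mmid_row_mapping records
+// the (i1, i2) destination indices of each gathered row.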
+struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+ int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+ const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+ int64_t ne11, int64_t ne10,
+ size_t nb11, size_t nb12) {
+ int32_t iid1 = blockIdx.x;
+ int32_t id = blockIdx.y;
+
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+ if (row_id_i != i02) {
+ return;
+ }
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ __shared__ int src1_row;
+ if (threadIdx.x == 0) {
+ src1_row = atomicAdd(cur_src1_row, 1);
+ row_mapping[src1_row] = {id, iid1};
+ }
+ __syncthreads();
+
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+ for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+ src1_row_contiguous[i] = src1_row_original[i];
+ }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+ const mmid_row_mapping * __restrict__ row_mapping,
+ int64_t ne0,
+ size_t nb1, size_t nb2) {
+ int32_t i = blockIdx.x;
+
+ const int32_t i1 = row_mapping[i].i1;
+ const int32_t i2 = row_mapping[i].i2;
+
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+ for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+ dst_row_original[j] = dst_row_contiguous[j];
+ }
+}
+
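+// The ids tensor is copied to the host to determine the routing. For a single token (ne12 == 1)
+// each selected expert row is multiplied directly against a one-row view of src1/dst; for larger
+// batches the matching src1 rows are first gathered into a contiguous buffer on the device with
+// k_copy_src1_to_contiguous so that a single ggml_cuda_mul_mat call per expert suffices.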
+static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * ids = dst->src[2];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t n_as = ne02;
+ const int64_t n_ids = ids->ne[0];
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ const char * ids_dev = (const char *) ids->data;
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+ CUDA_CHECK(cudaStreamSynchronize(stream));
+
+ ggml_tensor src0_row = *src0;
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ char * src0_original = (char *) src0->data;
+ char * src1_original = (char *) src1->data;
+ char * dst_original = (char *) dst->data;
+
+ src0_row.ne[2] = 1;
+ src0_row.ne[3] = 1;
+ src0_row.nb[3] = nb02;
+
+ src1_row.ne[1] = 1;
+ src1_row.ne[2] = 1;
+ src1_row.ne[3] = 1;
+ src1_row.nb[2] = nb11;
+ src1_row.nb[3] = nb11;
+
+ dst_row.ne[1] = 1;
+ dst_row.ne[2] = 1;
+ dst_row.ne[3] = 1;
+ dst_row.nb[2] = nb1;
+ dst_row.nb[3] = nb1;
+
+ if (ne12 == 1) {
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = id;
+ const int64_t i2 = i12;
+
+ src0_row.data = src0_original + i02*nb02;
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+ }
+ }
+ } else {
+ ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+ ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+
+ src1_row.data = src1_contiguous.get();
+ dst_row.data = dst_contiguous.get();
+
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
+ int64_t num_src1_rows = 0;
+
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+ if (row_id_i != i02) {
+ continue;
+ }
+
+ num_src1_rows++;
+ }
+ }
+
+ if (num_src1_rows == 0) {
+ continue;
+ }
+
+ ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+ ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+ CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+ {
+ dim3 block_dims(std::min((unsigned int)ne10, 768u));
+ dim3 grid_dims(ids->ne[1], n_ids);
+ k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+ src1_original, src1_contiguous.get(),
+ dev_cur_src1_row.get(), dev_row_mapping.get(),
+ ids_dev, i02, ids->nb[1], ids->nb[0],
+ ne11, ne10,
+ nb11, nb12);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ src0_row.data = src0_original + i02*nb02;
+
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+ src1_row.ne[1] = num_src1_rows;
+ src1_row.nb[1] = nb11;
+ src1_row.nb[2] = num_src1_rows*nb11;
+ src1_row.nb[3] = num_src1_rows*nb11;
+
+ dst_row.ne[1] = num_src1_rows;
+ dst_row.nb[1] = nb1;
+ dst_row.nb[2] = num_src1_rows*nb1;
+ dst_row.nb[3] = num_src1_rows*nb1;
+
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+
+ {
+ dim3 block_dims(std::min((unsigned int)ne0, 768u));
+ dim3 grid_dims(num_src1_rows);
+ k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+ dst_original, dst_contiguous.get(),
+ dev_row_mapping.get(),
+ ne0,
+ nb1, nb2);
+ CUDA_CHECK(cudaGetLastError());
+ }
+ }
+ }
+}
+
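+// Dispatches a single ggml op to its CUDA implementation. Returning false marks the op as
+// unsupported so the caller can report it; any pending CUDA error is surfaced afterwards.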
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+ // why is this here instead of mul_mat?
+ if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) {
+ ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
+ }
+
+ switch (dst->op) {
+ case GGML_OP_REPEAT:
+ ggml_cuda_op_repeat(ctx, dst);
+ break;
+ case GGML_OP_GET_ROWS:
+ ggml_cuda_op_get_rows(ctx, dst);
+ break;
+ case GGML_OP_DUP:
+ ggml_cuda_dup(ctx, dst);
+ break;
+ case GGML_OP_CPY:
+ ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+ break;
+ case GGML_OP_CONT:
+ ggml_cuda_dup(ctx, dst);
+ break;
+ case GGML_OP_ADD:
+ ggml_cuda_op_add(ctx, dst);
+ break;
+ case GGML_OP_ACC:
+ ggml_cuda_op_acc(ctx, dst);
+ break;
+ case GGML_OP_MUL:
+ ggml_cuda_op_mul(ctx, dst);
+ break;
+ case GGML_OP_DIV:
+ ggml_cuda_op_div(ctx, dst);
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_GELU:
+ ggml_cuda_op_gelu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SILU:
+ ggml_cuda_op_silu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ ggml_cuda_op_gelu_quick(ctx, dst);
+ break;
+ case GGML_UNARY_OP_TANH:
+ ggml_cuda_op_tanh(ctx, dst);
+ break;
+ case GGML_UNARY_OP_RELU:
+ ggml_cuda_op_relu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SIGMOID:
+ ggml_cuda_op_sigmoid(ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ ggml_cuda_op_hardsigmoid(ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ ggml_cuda_op_hardswish(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_NORM:
+ ggml_cuda_op_norm(ctx, dst);
+ break;
+ case GGML_OP_GROUP_NORM:
+ ggml_cuda_op_group_norm(ctx, dst);
+ break;
+ case GGML_OP_CONCAT:
+ ggml_cuda_op_concat(ctx, dst);
+ break;
+ case GGML_OP_UPSCALE:
+ ggml_cuda_op_upscale(ctx, dst);
+ break;
+ case GGML_OP_PAD:
+ ggml_cuda_op_pad(ctx, dst);
+ break;
+ case GGML_OP_ARANGE:
+ ggml_cuda_op_arange(ctx, dst);
+ break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ ggml_cuda_op_timestep_embedding(ctx, dst);
+ break;
+ case GGML_OP_LEAKY_RELU:
+ ggml_cuda_op_leaky_relu(ctx, dst);
+ break;
+ case GGML_OP_RMS_NORM:
+ ggml_cuda_op_rms_norm(ctx, dst);
+ break;
+ case GGML_OP_MUL_MAT:
+ if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
+ GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+ return false;
+ } else {
+ ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
+ }
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ ggml_cuda_mul_mat_id(ctx, dst);
+ break;
+ case GGML_OP_SCALE:
+ ggml_cuda_op_scale(ctx, dst);
+ break;
+ case GGML_OP_SQR:
+ ggml_cuda_op_sqr(ctx, dst);
+ break;
+ case GGML_OP_SQRT:
+ ggml_cuda_op_sqrt(ctx, dst);
+ break;
+ case GGML_OP_CLAMP:
+ ggml_cuda_op_clamp(ctx, dst);
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ ggml_cuda_op_diag_mask_inf(ctx, dst);
+ break;
+ case GGML_OP_SOFT_MAX:
+ ggml_cuda_op_soft_max(ctx, dst);
+ break;
+ case GGML_OP_ROPE:
+ ggml_cuda_op_rope(ctx, dst);
+ break;
+ case GGML_OP_IM2COL:
+ ggml_cuda_op_im2col(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx, dst);
+ break;
+ case GGML_OP_POOL_2D:
+ ggml_cuda_op_pool2d(ctx, dst);
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_cuda_op_sum_rows(ctx, dst);
+ break;
+ case GGML_OP_ARGSORT:
+ ggml_cuda_op_argsort(ctx, dst);
+ break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_cuda_flash_attn_ext(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
+ CUDA_CHECK(err);
+ }
+
+ return true;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend
+
+GGML_CALL static const char * ggml_backend_cuda_name(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ return cuda_ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_cuda_free(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ delete cuda_ctx;
+ delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ return ggml_backend_cuda_buffer_type(cuda_ctx->device);
+}
+
+GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
+}
+
+GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
+}
+
+GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst));
+
+ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+ if (!ggml_backend_buffer_is_cuda(src->buffer)) {
+ return false;
+ }
+
+ if (!ggml_backend_buffer_is_cuda(dst->buffer)) {
+ return false;
+ }
+
+ // device -> device
+ ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
+ ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+
+ if (backend_src != backend_dst) {
+ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+ ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+
+ GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
+ GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
+
+ // copy on src stream
+ if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
+ } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return false;
+#else
+ CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
+#endif
+ }
+
+ // record event on src stream
+ if (!cuda_ctx_src->copy_event) {
+ ggml_cuda_set_device(cuda_ctx_src->device);
+ CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+ }
+
+ CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+
+ // wait on dst stream for the copy to complete
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+ } else {
+ // src and dst are on the same backend
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
+ }
+ return true;
+}
+
+GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
+
+ GGML_UNUSED(backend);
+}
+
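+// CUDA graph support: the op, shapes, strides and data addresses of every node in the ggml graph
+// are recorded after each evaluation and compared against the next one to decide whether the
+// previously captured CUDA graph can be reused or must be updated/re-captured.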
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+ graph_node_properties->node_address = node->data;
+ graph_node_properties->node_op = node->op;
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ graph_node_properties->ne[i] = node->ne[i];
+ graph_node_properties->nb[i] = node->nb[i];
+ }
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+ }
+}
+
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+ if (node->data != graph_node_properties->node_address &&
+ node->op != GGML_OP_CPY &&
+ node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ if (node->op != graph_node_properties->node_op) {
+ return false;
+ }
+
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (node->ne[i] != graph_node_properties->ne[i]) {
+ return false;
+ }
+ if (node->nb[i] != graph_node_properties->nb[i]) {
+ return false;
+ }
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node->src[i] &&
+ node->src[i]->data != graph_node_properties->src_address[i] &&
+ node->op != GGML_OP_CPY &&
+ node->op != GGML_OP_VIEW
+ ) {
+ return false;
+ }
+ }
+ return true;
+}
+
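+// Graph execution. When CUDA graphs are enabled (USE_CUDA_GRAPH) and applicable, the ggml graph
+// is captured into a CUDA graph on the first evaluation (and whenever its properties change) and
+// replayed with cudaGraphLaunch afterwards; on replay only the copy-kernel source pointers, which
+// change every token, are patched into the instantiated graph. Capture is disabled for pre-Ampere
+// GPUs, split buffers, MUL_MAT_ID, batch sizes > 1, the GGML_CUDA_DISABLE_GRAPHS env var, or
+// after too many consecutive graph updates.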
+GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ ggml_cuda_set_device(cuda_ctx->device);
+
+#ifdef USE_CUDA_GRAPH
+ static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+ // Objects required for CUDA Graph
+ if (cuda_ctx->cuda_graph == nullptr) {
+ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+ }
+
+ bool use_cuda_graph = true;
+ bool cuda_graph_update_required = false;
+ // vector of pointers to CUDA cpy kernels, which are required to identify
+    // the kernel parameters that need to be updated in the graph for each token
+ std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+ if (cuda_ctx->cuda_graph->graph == nullptr) {
+ if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+ cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+ }
+ }
+
+    // Disable CUDA graphs in the presence of the env var, an old GPU, a use case that is changing
+    // too rapidly, or a previous graph capture failure.
+    // Also disable for multi-GPU for now. TODO: investigate.
+ if (disable_cuda_graphs_due_to_env
+ || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+ || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+ || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+ use_cuda_graph = false;
+ }
+
+ if (use_cuda_graph) {
+ if (cuda_ctx->cuda_graph->instance == nullptr) {
+ cuda_graph_update_required = true;
+ }
+
+ // Check if the graph size has changed
+ if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+ cuda_graph_update_required = true;
+ cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+ }
+
+ // Loop over nodes in GGML graph to determine if CUDA graph update is required
+ // and store properties to allow this comparison for the next token
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ bool has_matching_properties = true;
+ if (!cuda_graph_update_required) {
+ has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+ }
+ if (!has_matching_properties) {
+ cuda_graph_update_required = true;
+ }
+ set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+ }
+
+ // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+ cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+ use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_MUL_MAT_ID) {
+ use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+ // disable CUDA graphs for batch size > 1 for now.
+ // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+ }
+
+ if (node->op == GGML_OP_CPY) {
+ // store the copy op parameter which changes with each token.
+ cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+ // store a pointer to each copy op CUDA kernel to identify it later
+ void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+ if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+ ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+ }
+ }
+
+ if (!use_cuda_graph) {
+ break;
+ }
+ }
+
+ // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+ if (use_cuda_graph && cuda_graph_update_required) {
+ cuda_ctx->cuda_graph->number_consecutive_updates++;
+ } else {
+ cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+ }
+
+ if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+ cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+ }
+ }
+
+ if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+ }
+
+#else
+ bool use_cuda_graph = false;
+ bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+ bool graph_evaluated_or_captured = false;
+
+ while (!graph_evaluated_or_captured) {
+ // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+ // With the use of CUDA graphs, the execution will be performed by the graph launch.
+ if (!use_cuda_graph || cuda_graph_update_required) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
+#ifndef NDEBUG
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j] != nullptr) {
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+ }
+ }
+#endif
+
+ bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+ if (!ok) {
+ GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+ }
+ }
+
+#ifdef USE_CUDA_GRAPH
+ if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+ if (cuda_ctx->cuda_graph->graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+ cuda_ctx->cuda_graph->graph = nullptr;
+ }
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+ if (disable_cuda_graphs_due_to_failed_capture) {
+ use_cuda_graph = false;
+ cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+ } else {
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
+ }
+#endif
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
+ } else {
+ graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+ }
+ }
+
+ if (use_cuda_graph) {
+ if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+ }
+
+ // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+ if (cuda_graph_update_required) {
+ // Extract nodes from graph
+ // First call with null argument gets number of nodes in graph
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+ // Subsequent call with non-null argument gets nodes
+ cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+ cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+ if (cuda_ctx->cuda_graph->num_nodes > 0) {
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+ // Loop over nodes, and extract kernel parameters from each node
+ for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+ cudaGraphNodeType node_type;
+ CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+ if (node_type == cudaGraphNodeTypeKernel) {
+ cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+ if (stat == cudaErrorInvalidDeviceFunction) {
+ // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+ // We don't need to update blas nodes, so clear error and move on.
+ cudaGetLastError();
+ } else {
+ GGML_ASSERT(stat == cudaSuccess);
+ }
+ }
+ }
+ }
+ }
+
+ // One of the arguments to the copy kernel is updated for each token, hence we need to
+ // replace that argument with the updated value in the CUDA graph
+ if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+ int k = 0;
+ for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+ if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+ char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+ cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+ CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+ }
+ }
+ }
+
+ // Update graph executable
+ cudaGraphExecUpdateResultInfo result_info;
+ cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+ if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+ GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
+#endif
+ // The pre-existing graph exec cannot be updated due to violated constraints
+ // so instead clear error and re-instantiate
+ cudaGetLastError();
+ CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+ cuda_ctx->cuda_graph->instance = nullptr;
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+ } else {
+ GGML_ASSERT(stat == cudaSuccess);
+ }
+ // Launch graph
+ CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+ graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+ }
+
+ return GGML_STATUS_SUCCESS;
+}
+
+GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_TANH:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ struct ggml_tensor * a = op->src[0];
+ if (op->op == GGML_OP_MUL_MAT) {
+ struct ggml_tensor * b = op->src[1];
+ if (a->ne[3] != b->ne[3]) {
+ return false;
+ }
+ }
+ switch (a->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q8_K:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_REPEAT:
+ case GGML_OP_CONCAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ } break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NORM:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CONT:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ return true;
+ case GGML_OP_ROPE:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_IM2COL:
+ case GGML_OP_POOL_2D:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_ACC:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
+ return true;
+ case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+#else
+ if (op->src[0]->ne[0] == 128) {
+ return true;
+ }
+ if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+ return true;
+ }
+ return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
+ op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ default:
+ return false;
+ }
+
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ if (ggml_backend_buft_is_cuda_split(buft)) {
+ return true;
+ }
+
+ if (ggml_backend_buft_is_cuda(buft)) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+ return buft_ctx->device == cuda_ctx->device;
+ }
+
+ return false;
+}
+
+GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+ const int min_batch_size = 32;
+
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+ GGML_UNUSED(backend);
+}
+
+static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return nullptr;
+#else
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ ggml_cuda_set_device(cuda_ctx->device);
+
+ cudaEvent_t event;
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+
+ return new ggml_backend_event {
+ /* .backend = */ backend,
+ /* .context = */ event,
+ };
+#endif
+}
+
+static void ggml_backend_cuda_event_free(ggml_backend_event_t event) {
+ CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
+
+ delete event;
+}
+
+static void ggml_backend_cuda_event_record(ggml_backend_event_t event) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)event->backend->context;
+
+ CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ if (ggml_backend_is_cuda(event->backend)) {
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
+ } else {
+#if 0
+ // untested
+ auto wait_fn = [](void * user_data) {
+ ggml_backend_event_t event = (ggml_backend_event_t)user_data;
+ ggml_backend_event_synchronize(event);
+ };
+
+ CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
+#endif
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) {
+ CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
+}
+
+static ggml_backend_i ggml_backend_cuda_interface = {
+ /* .get_name = */ ggml_backend_cuda_name,
+ /* .free = */ ggml_backend_cuda_free,
+ /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type,
+ /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
+ /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
+ /* .synchronize = */ ggml_backend_cuda_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_cuda_graph_compute,
+ /* .supports_op = */ ggml_backend_cuda_supports_op,
+ /* .supports_buft = */ ggml_backend_cuda_supports_buft,
+ /* .offload_op = */ ggml_backend_cuda_offload_op,
+ /* .event_new = */ ggml_backend_cuda_event_new,
+ /* .event_free = */ ggml_backend_cuda_event_free,
+ /* .event_record = */ ggml_backend_cuda_event_record,
+ /* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .event_synchronize = */ ggml_backend_cuda_event_synchronize,
+};
+
+static ggml_guid_t ggml_backend_cuda_guid() {
+ static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
+ return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
+ if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
+ GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
+ return nullptr;
+ }
+
+ ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
+ if (ctx == nullptr) {
+ GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__);
+ return nullptr;
+ }
+
+ ggml_backend_t cuda_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_cuda_guid(),
+ /* .interface = */ ggml_backend_cuda_interface,
+ /* .context = */ ctx
+ };
+
+ return cuda_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
+}
+
+GGML_CALL int ggml_backend_cuda_get_device_count() {
+ return ggml_cuda_info().device_count;
+}
+
+GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+ snprintf(description, description_size, "%s", prop.name);
+}
+
+GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
+ ggml_cuda_set_device(device);
+
+ CUDA_CHECK(cudaMemGetInfo(free, total));
+}
+
+GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
+ if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+ return false;
+ }
+
+#if CUDART_VERSION >= 11100
+ cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+ if (err != cudaSuccess) {
+ // clear the error
+ cudaGetLastError();
+
+ GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+ size / 1024.0 / 1024.0, cudaGetErrorString(err));
+ return false;
+ }
+ return true;
+#else
+ return false;
+#endif
+}
+
+GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
+ if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+ return;
+ }
+
+ cudaError_t err = cudaHostUnregister(buffer);
+ if (err != cudaSuccess) {
+ // clear the error
+ cudaGetLastError();
+ }
+}
+
+// backend registry
+GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
+ ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+ return cuda_backend;
+
+ GGML_UNUSED(params);
+}
+
+extern "C" GGML_CALL int ggml_backend_cuda_reg_devices();
+
+GGML_CALL int ggml_backend_cuda_reg_devices() {
+ int device_count = ggml_backend_cuda_get_device_count();
+ //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
+ for (int i = 0; i < device_count; i++) {
+ char name[128];
+ snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
+ ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i);
+ }
+ return device_count;
+}
diff --git a/ggml/src/ggml-cuda/acc.cu b/ggml/src/ggml-cuda/acc.cu
new file mode 100644
index 00000000..96bfe1c9
--- /dev/null
+++ b/ggml/src/ggml-cuda/acc.cu
@@ -0,0 +1,47 @@
+#include "acc.cuh"
+
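+// GGML_OP_ACC: dst is a copy of x (src0) with y (src1) added into a sub-view whose layout is
+// given by the element strides nb1/nb2 and the element offset (derived from the byte values in
+// dst->op_params).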
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+ const int ne10, const int ne11, const int ne12,
+ const int nb1, const int nb2, int offset) {
+ const int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i >= ne) {
+ return;
+ }
+ int src1_idx = i - offset;
+ int oz = src1_idx / nb2;
+ int oy = (src1_idx - (oz * nb2)) / nb1;
+ int ox = src1_idx % nb1;
+ if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+ dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+ } else {
+ dst[i] = x[i];
+ }
+}
+
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+ const int ne10, const int ne11, const int ne12,
+ const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+ int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+ acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // byte offset converted to float elements
+
+ acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+}
diff --git a/ggml/src/ggml-cuda/acc.cuh b/ggml/src/ggml-cuda/acc.cuh
new file mode 100644
index 00000000..1168ea1b
--- /dev/null
+++ b/ggml/src/ggml-cuda/acc.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ACC_BLOCK_SIZE 256
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/arange.cu b/ggml/src/ggml-cuda/arange.cu
new file mode 100644
index 00000000..b5e495a2
--- /dev/null
+++ b/ggml/src/ggml-cuda/arange.cu
@@ -0,0 +1,34 @@
+#include "arange.cuh"
+
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIdx.x: index into the ne0 / BLOCK_SIZE blocks
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ dst[nidx] = start + step * nidx;
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+ arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+}
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float stop;
+ float step;
+ memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
+ memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
+
+ int64_t steps = (int64_t)ceil((stop - start) / step);
+ GGML_ASSERT(ggml_nelements(dst) == steps);
+
+ arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
+}
diff --git a/ggml/src/ggml-cuda/arange.cuh b/ggml/src/ggml-cuda/arange.cuh
new file mode 100644
index 00000000..41e74fdf
--- /dev/null
+++ b/ggml/src/ggml-cuda/arange.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ARANGE_BLOCK_SIZE 256
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
new file mode 100644
index 00000000..15757ca1
--- /dev/null
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -0,0 +1,104 @@
+#include "argsort.cuh"
+
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+ T tmp = a;
+ a = b;
+ b = tmp;
+}
+
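+// Bitonic argsort over each row: the column count is padded to the next power of two, a single
+// thread block sorts the index array for a whole row in shared memory, and the padding indices
+// (>= ncols) are sorted past the valid columns and dropped when the result is written back.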
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+ // bitonic sort
+ int col = threadIdx.x;
+ int row = blockIdx.y;
+
+ if (col >= ncols_pad) {
+ return;
+ }
+
+ const float * x_row = x + row * ncols;
+ extern __shared__ int dst_row[];
+
+ // initialize indices
+ dst_row[col] = col;
+
+ __syncthreads();
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+static int next_power_of_2(int x) {
+ int n = 1;
+ while (n < x) {
+ n *= 2;
+ }
+ return n;
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ // bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ const dim3 block_dims(ncols_pad, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+ GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else {
+ GGML_ASSERT(false);
+ }
+}
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+}
diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh
new file mode 100644
index 00000000..68a00154
--- /dev/null
+++ b/ggml/src/ggml-cuda/argsort.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
new file mode 100644
index 00000000..76cc01b2
--- /dev/null
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -0,0 +1,316 @@
+#include "binbcast.cuh"
+
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+ return b;
+ GGML_UNUSED(a);
+}
+
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+ return a + b;
+}
+
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+ return a * b;
+}
+
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+ return a / b;
+}
+
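+// Generic broadcasting binary ops (add/mul/div/repeat): src1 indices are taken modulo its sizes,
+// which implements ggml-style broadcasting. The host wrapper below collapses leading
+// non-broadcast dimensions when all tensors are contiguous and falls back to a flat 1D kernel
+// when the grid would exceed the 65535-block limit in the z dimension.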
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s00,*/ int s01, int s02, int s03,
+ /*int s10,*/ int s11, int s12, int s13) {
+ const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+ const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+ const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+ const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
+
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s00,*/ int s01, int s02, int s03,
+ /*int s10,*/ int s11, int s12, int s13) {
+
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const int i3 = i/(ne2*ne1*ne0);
+ const int i2 = (i/(ne1*ne0)) % ne2;
+ const int i1 = (i/ne0) % ne1;
+ const int i0 = i % ne0;
+
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+ template<typename src0_t, typename src1_t, typename dst_t>
+ void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10/ne0;
+ int nr1 = ne11/ne1;
+ int nr2 = ne12/ne2;
+ int nr3 = ne13/ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ // collapse dimensions until first broadcast dimension
+ int64_t cne[] = {ne0, ne1, ne2, ne3};
+ int64_t cne0[] = {ne00, ne01, ne02, ne03};
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
+
+ size_t cnb[] = {nb0, nb1, nb2, nb3};
+ size_t cnb0[] = {nb00, nb01, nb02, nb03};
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0];
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s00 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+ dim3 block_dims;
+ block_dims.x = std::min<unsigned int>(hne0, block_size);
+ block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums(
+ (hne0 + block_dims.x - 1) / block_dims.x,
+ (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2*ne3 + block_dims.z - 1) / block_dims.z
+ );
+
+ if (block_nums.z > 65535) {
+                // 65535 is the maximum number of blocks in the z dimension; fall back to the 1D grid kernel
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+ k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00, */ s01, s02, s03,
+ /* s10, */ s11, s12, s13);
+ } else {
+ k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13,
+ /* s0, */ s1, s2, s3,
+ /* s00, */ s01, s02, s03,
+ /* s10, */ s11, s12, s13);
+ }
+ }
+ }
+};
+
+template<class op>
+static void ggml_cuda_op_bin_bcast(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
+
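+// repeat: dst is passed in the src0 position so that its shape drives the broadcast;
+// src0_dd is nullptr, so the kernel substitutes 0.0f and the values come from dst->src[0]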
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
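+// reads the scale factor directly from device memory via the data pointer,
+// so no host-side copy of the scalar is needed before launching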
+static __global__ void scale_f32_l(const float * x, float * dst, const void * data, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float * scale = (const float *)data;
+ dst[i] = scale[0] * x[i];
+}
+
+static void scale_f32_cuda_l(const float * x, float * dst, const void * data, const int k, cudaStream_t stream) {
+ constexpr int CUDA_SCALE_BLOCK_SIZE = 512; //256;
+ const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+ scale_f32_l<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, data, k);
+}
+
+void ggml_cuda_op_scale_tensor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ scale_f32_cuda_l(src0_d, dst_d, dst->src[1]->data, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ if (ggml_nelements(dst->src[1]) == 1 && dst->src[1]->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32) {
+ ggml_cuda_op_scale_tensor(ctx, dst);
+ return;
+ }
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
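A minimal host-side sketch of the collapse()/collapse_nb() step used by bin_bcast_cuda::operator() above, which fuses leading non-broadcast dimensions so the kernel sees fewer, larger rows. The shape {4096, 32, 8, 1} and its contiguous f32 strides are made-up values chosen only to replay the arithmetic of one collapse step:

#include <cstdint>
#include <cstdio>

// same lambdas as in bin_bcast_cuda::operator(), written as free functions
static void collapse(int64_t cne[4]) {
    cne[0] *= cne[1];
    cne[1] = cne[2];
    cne[2] = cne[3];
    cne[3] = 1;
}

static void collapse_nb(size_t cnb[4], const int64_t cne[4]) {
    cnb[1] *= cne[1];
    cnb[2] *= cne[2];
    cnb[3] *= cne[3];
}

int main() {
    int64_t cne[4] = {4096, 32, 8, 1};
    size_t  cnb[4] = {4, 4*4096, 4*4096*32, 4*4096*32*8}; // contiguous f32 strides in bytes

    collapse_nb(cnb, cne);   // strides first, while cne still holds the old extents
    collapse(cne);

    printf("collapsed ne: %lld %lld %lld %lld\n",
           (long long)cne[0], (long long)cne[1], (long long)cne[2], (long long)cne[3]);
    printf("collapsed nb: %zu %zu %zu %zu\n", cnb[0], cnb[1], cnb[2], cnb[3]);
    // -> ne = {131072, 8, 1, 1}, nb = {4, 524288, 4194304, 4194304}:
    //    for contiguous tensors the strides line up with the fused row dimension
    return 0;
}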
diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh
new file mode 100644
index 00000000..4f63d637
--- /dev/null
+++ b/ggml/src/ggml-cuda/binbcast.cuh
@@ -0,0 +1,6 @@
+#include "common.cuh"
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/clamp.cu b/ggml/src/ggml-cuda/clamp.cu
new file mode 100644
index 00000000..8009a3e3
--- /dev/null
+++ b/ggml/src/ggml-cuda/clamp.cu
@@ -0,0 +1,34 @@
+#include "clamp.cuh"
+
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
+
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+ clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float min;
+ float max;
+ memcpy(&min, dst->op_params, sizeof(float));
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+ clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+}
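The clamp bounds live in dst->op_params as two packed floats and are read back with memcpy, as in ggml_cuda_op_clamp above. A minimal host-side sketch of that unpacking and of the exact clamp expression used by clamp_f32, with made-up values:

#include <cstring>
#include <cstdio>

int main() {
    float op_params[2];
    op_params[0] = -1.0f;  // min
    op_params[1] =  1.0f;  // max

    float min_v, max_v;
    memcpy(&min_v, op_params, sizeof(float));
    memcpy(&max_v, (const float *) op_params + 1, sizeof(float));

    const float x = 3.5f;
    const float y = x < min_v ? min_v : (x > max_v ? max_v : x);  // same expression as clamp_f32
    printf("clamp(%g, [%g, %g]) = %g\n", x, min_v, max_v, y);     // -> 1
    return 0;
}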
diff --git a/ggml/src/ggml-cuda/clamp.cuh b/ggml/src/ggml-cuda/clamp.cuh
new file mode 100644
index 00000000..7f9559dd
--- /dev/null
+++ b/ggml/src/ggml-cuda/clamp.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CLAMP_BLOCK_SIZE 256
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
new file mode 100644
index 00000000..7ea93264
--- /dev/null
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -0,0 +1,871 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cuda.h"
+
+#include <cstdint>
+#include <memory>
+
+#if defined(GGML_USE_HIPBLAS)
+#define GGML_COMMON_DECL_HIP
+#define GGML_COMMON_IMPL_HIP
+#else
+#define GGML_COMMON_DECL_CUDA
+#define GGML_COMMON_IMPL_CUDA
+#endif
+#include "ggml-common.h"
+
+#include <cstdio>
+#include <array>
+#include <cassert>
+#include <cfloat>
+#include <string>
+#include <vector>
+
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif // __HIP_PLATFORM_AMD__
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
+#define cublasCreate hipblasCreate
+#define cublasDestroy hipblasDestroy
+#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchHostFunc hipLaunchHostFunc
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
+#define cudaMemGetInfo hipMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStreamFireAndForget hipStreamFireAndForget
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#else
+#include <cuda_runtime.h>
+#include <cuda.h>
+#include <cublas_v2.h>
+#include <cuda_fp16.h>
+
+#if CUDART_VERSION < 11020
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#endif // CUDART_VERSION < 11020
+
+#endif // defined(GGML_USE_HIPBLAS)
+
+#define STRINGIZE_IMPL(...) #__VA_ARGS__
+#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
+
+#define WARP_SIZE 32
+#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
+
+#define CC_PASCAL 600
+#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_VOLTA 700
+#define CC_TURING 750
+#define CC_AMPERE 800
+#define CC_OFFSET_AMD 1000000
+#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
+#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
+#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define GGML_CUDA_MAX_STREAMS 8
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
+
+#define CUDA_CHECK_GEN(err, success, error_fn) \
+ do { \
+ auto err_ = (err); \
+ if (err_ != (success)) { \
+ ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
+ } \
+ } while (0)
+
+#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
+
+#if CUDART_VERSION >= 12000
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ return cublasGetStatusString(err);
+ }
+#else
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ switch (err) {
+ case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+ case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+ default: return "unknown error";
+ }
+ }
+#endif // CUDART_VERSION >= 12000
+
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
+
+#if !defined(GGML_USE_HIPBLAS)
+static const char * cu_get_error_str(CUresult err) {
+ const char * err_str;
+ cuGetErrorString(err, &err_str);
+ return err_str;
+}
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
+#endif
+
+#if CUDART_VERSION >= 11100
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
+#ifdef GGML_CUDA_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif //GGML_CUDA_F16
+
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+ defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif
+
+#ifndef __has_builtin
+ #define __has_builtin(x) 0
+#endif
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+ return reinterpret_cast<const int &>(c);
+#else
+ int8x4_t c;
+ int16_t tmp;
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ tmp = va[i] - vb[i];
+ if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+ if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+ c[i] = tmp;
+ }
+ return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+ return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+ unsigned int c;
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+ }
+ return c;
+}
+
+static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+ unsigned int c;
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+ }
+ return c;
+}
+
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+ typedef union half2_b32 {
+ half2 val;
+ int b32;
+ } half2_b32_t;
+ half2_b32_t tmp;
+ tmp.val = var;
+ tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
+ return tmp.val;
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+#endif // defined(GGML_USE_HIPBLAS)
+
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+
+static constexpr bool fast_fp16_available(const int cc) {
+ return cc >= CC_PASCAL && cc != 610;
+}
+
+static constexpr bool fp16_mma_available(const int cc) {
+ return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
+static constexpr bool int8_mma_available(const int cc) {
+ return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
+[[noreturn]]
+static __device__ void no_device_code(
+ const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+ file_name, line, function_name, arch);
+ GGML_UNUSED(arch_list);
+#else
+ printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+ file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ __trap();
+
+ GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+ }
+ return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+ reinterpret_cast<half&>(a.x) += __low2half(a_other);
+ reinterpret_cast<half&>(a.y) += __high2half(a_other);
+ }
+ return a;
+#else
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+ }
+ return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+ NO_DEVICE_CODE;
+ return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+}
+
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#ifdef FP16_AVAILABLE
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+ return __float2half(fmaxf(__half2float(a), __half2float(b)));
+#else
+ return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+
+#else
+ NO_DEVICE_CODE;
+ GGML_UNUSED(b);
+ return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+ return __hmax2(a, b);
+#else
+ half2 ret;
+ reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
+ reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
+ return ret;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+ GGML_UNUSED(a);
+ GGML_UNUSED(b);
+ NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+ }
+ return x;
+#else
+ GGML_UNUSED(x);
+ NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+#if CUDART_VERSION < CUDART_HMASK
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+ const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+ const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+ return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < CUDART_HMASK
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+ c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+ c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+ int tmp1;
+ int tmp2;
+ asm("\n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ "
+ : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+ : "v"(a), "v"(b)
+ );
+#else
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+ c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+ return c;
+
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
+ return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= MIN_CC_DP4A
+ const int8_t * a8 = (const int8_t *) &a;
+ const int8_t * b8 = (const int8_t *) &b;
+ return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+// TODO: move to ggml-common.h
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+
+static __device__ __forceinline__ float get_alibi_slope(
+ const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+ if (max_bias <= 0.0f) {
+ return 1.0f;
+ }
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ return powf(base, exph);
+}
+
+static __device__ __forceinline__ float iq1bn_fp8_to_float(uint8_t fp8) {
+ typedef union { float f; uint32_t i; } scale_t;
+ scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18);
+ return s.f;
+}
+
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+ static constexpr int qk = 1;
+ static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+ static constexpr int qk = QK4_0;
+ static constexpr int qr = QR4_0;
+ static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+ static constexpr int qk = QK4_1;
+ static constexpr int qr = QR4_1;
+ static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+ static constexpr int qk = QK5_0;
+ static constexpr int qr = QR5_0;
+ static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+ static constexpr int qk = QK5_1;
+ static constexpr int qr = QR5_1;
+ static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+ static constexpr int qk = QK8_0;
+ static constexpr int qr = QR8_0;
+ static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_K;
+ static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_K;
+ static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_K;
+ static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR5_K;
+ static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR6_K;
+ static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_XXS;
+ static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_XS;
+ static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_S;
+ static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_XXS;
+ static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR1_S;
+ static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR1_M;
+ static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_BN> {
+ static constexpr int qk = QK_IQ1BN;
+ static constexpr int qr = QR1_BN;
+ static constexpr int qi = QI1_BN;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_BN> {
+ static constexpr int qk = QK_IQ1BN;
+ static constexpr int qr = QR1_BN;
+ static constexpr int qi = QI1_BN;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+ static constexpr int qk = QK4_NL;
+ static constexpr int qr = QR4_NL;
+ static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_S;
+ static constexpr int qi = QI3_S;
+};
+
+//////////////////////
+
+struct ggml_cuda_device_info {
+ int device_count;
+
+ struct cuda_device_info {
+ int cc; // compute capability
+ int nsm; // number of streaming multiprocessors
+ size_t smpb; // max. shared memory per block
+ size_t smpbo; // max. shared memory per block (with opt-in)
+ bool vmm; // virtual memory support
+ size_t vmm_granularity; // granularity of virtual memory
+ size_t total_vram;
+ };
+
+ cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
+};
+
+const ggml_cuda_device_info & ggml_cuda_info();
+
+void ggml_cuda_set_device(int device);
+int ggml_cuda_get_device();
+
+struct ggml_cuda_pool {
+ virtual ~ggml_cuda_pool() = default;
+
+ virtual void * alloc(size_t size, size_t * actual_size) = 0;
+ virtual void free(void * ptr, size_t size) = 0;
+};
+
+template<typename T>
+struct ggml_cuda_pool_alloc {
+ ggml_cuda_pool * pool = nullptr;
+ T * ptr = nullptr;
+ size_t actual_size = 0;
+
+ ggml_cuda_pool_alloc() = default;
+
+ explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
+ }
+
+ ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
+ alloc(size);
+ }
+
+ ~ggml_cuda_pool_alloc() {
+ if (ptr != nullptr) {
+ pool->free(ptr, actual_size);
+ }
+ }
+
+ // size is in number of elements
+ T * alloc(size_t size) {
+ GGML_ASSERT(pool != nullptr);
+ GGML_ASSERT(ptr == nullptr);
+ ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+ return ptr;
+ }
+
+ T * alloc(ggml_cuda_pool & pool, size_t size) {
+ this->pool = &pool;
+ return alloc(size);
+ }
+
+ T * get() {
+ return ptr;
+ }
+
+ ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
+ ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
+ ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
+ ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
+};
+
+
+// backend interface
+
+struct ggml_tensor_extra_gpu {
+ void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+ cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
+};
+
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+ void * node_address;
+ ggml_op node_op;
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+ void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+ ~ggml_cuda_graph() {
+ if (instance != nullptr) {
+ CUDA_CHECK(cudaGraphExecDestroy(instance));
+ }
+ if (graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(graph));
+ }
+ }
+ cudaGraph_t graph = nullptr;
+ cudaGraphExec_t instance = nullptr;
+ size_t num_nodes = 0;
+ std::vector<cudaGraphNode_t> nodes;
+ std::vector<cudaKernelNodeParams> params;
+ bool disable_due_to_gpu_arch = false;
+ bool disable_due_to_too_many_updates = false;
+ bool disable_due_to_failed_graph_capture = false;
+ int number_consecutive_updates = 0;
+ std::vector<ggml_graph_node_properties> ggml_graph_properties;
+ std::vector<char **> updated_kernel_arg;
+#endif
+};
+
+struct ggml_backend_cuda_context {
+ int device;
+ std::string name;
+ cudaEvent_t copy_event = nullptr;
+
+ cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
+ cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+ std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
+ explicit ggml_backend_cuda_context(int device) :
+ device(device),
+ name(GGML_CUDA_NAME + std::to_string(device)) {
+ }
+
+ ~ggml_backend_cuda_context() {
+ if (copy_event != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(copy_event));
+ }
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+ for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+ if (streams[i][j] != nullptr) {
+ CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+ }
+ }
+ if (cublas_handles[i] != nullptr) {
+ CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+ }
+ }
+ }
+
+ cudaStream_t stream(int device, int stream) {
+ if (streams[device][stream] == nullptr) {
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
+ }
+ return streams[device][stream];
+ }
+
+ cudaStream_t stream() {
+ return stream(device, 0);
+ }
+
+ cublasHandle_t cublas_handle(int device) {
+ if (cublas_handles[device] == nullptr) {
+ ggml_cuda_set_device(device);
+ CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
+ CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
+ }
+ return cublas_handles[device];
+ }
+
+ cublasHandle_t cublas_handle() {
+ return cublas_handle(device);
+ }
+
+ // pool
+ std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+
+ static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+
+ ggml_cuda_pool & pool(int device) {
+ if (pools[device] == nullptr) {
+ pools[device] = new_pool_for_device(device);
+ }
+ return *pools[device];
+ }
+
+ ggml_cuda_pool & pool() {
+ return pool(device);
+ }
+};
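The warp_reduce_sum/warp_reduce_max helpers above use an XOR butterfly over __shfl_xor_sync: after log2(32) = 5 steps every lane holds the full warp result. A small host-side emulation of the same pattern over a 32-element array (illustration only, no CUDA required):

#include <cstdio>

int main() {
    const int WARP = 32;
    float lane[WARP];
    for (int i = 0; i < WARP; ++i) lane[i] = (float) (i + 1);   // 1 + 2 + ... + 32 = 528

    for (int mask = 16; mask > 0; mask >>= 1) {
        float next[WARP];
        for (int i = 0; i < WARP; ++i) {
            next[i] = lane[i] + lane[i ^ mask];  // lane i reads from lane i^mask, like __shfl_xor_sync
        }
        for (int i = 0; i < WARP; ++i) lane[i] = next[i];
    }

    printf("every lane now holds the warp sum: %g\n", lane[0]);  // -> 528
    return 0;
}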
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
new file mode 100644
index 00000000..dac10ec3
--- /dev/null
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -0,0 +1,196 @@
+#include "concat.cuh"
+
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (nidx < ne00) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * ne00 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ (nidx - ne00) +
+ blockIdx.y * (ne0 - ne00) +
+ blockIdx.z * (ne0 - ne00) * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (blockIdx.y < ne01) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ (blockIdx.y - ne01) * ne0 +
+ blockIdx.z * ne0 * (gridDim.y - ne01);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (blockIdx.z < ne02) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ (blockIdx.z - ne02) * ne0 * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2);
+ if (dim == 0) {
+ concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+ return;
+ }
+ if (dim == 1) {
+ concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+ return;
+ }
+ concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+// non-contiguous kernel (slow)
+static __global__ void concat_f32_non_cont(
+ const char * src0,
+ const char * src1,
+ char * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne03,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
+ int64_t /*ne10*/,
+ int64_t /*ne11*/,
+ int64_t /*ne12*/,
+ int64_t /*ne13*/,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
+ int64_t ne0,
+ int64_t /*ne1*/,
+ int64_t /*ne2*/,
+ int64_t /*ne3*/,
+ uint64_t nb0,
+ uint64_t nb1,
+ uint64_t nb2,
+ uint64_t nb3,
+ int32_t dim) {
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+ const float * x;
+
+ for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ } else {
+ x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ }
+
+ float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ *y = *x;
+ }
+}
+
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ cudaStream_t stream = ctx.stream();
+
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+
+ if (dim != 3) {
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(
+ src0_d + i3 * (src0->nb[3] / 4),
+ src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * ( dst->nb[3] / 4),
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+
+ CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ }
+ } else {
+ dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+ concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+ (const char *)src0->data,
+ (const char *)src1->data,
+ ( char *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
+ }
+}
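concat_f32_dim0 above splits each destination row at ne00: indices below ne00 read from src0 and the rest read from src1 shifted by ne00 (the dim1/dim2 kernels apply the same idea to blockIdx.y and blockIdx.z). A tiny host sketch of that index mapping, with made-up row lengths:

#include <cstdio>

int main() {
    const int ne00 = 3;          // src0 row length
    const int ne0  = 5;          // dst row length; src1 contributes ne0 - ne00 = 2 elements

    for (int nidx = 0; nidx < ne0; ++nidx) {
        if (nidx < ne00) {
            printf("dst[%d] <- src0[%d]\n", nidx, nidx);          // same branch as the kernel
        } else {
            printf("dst[%d] <- src1[%d]\n", nidx, nidx - ne00);
        }
    }
    return 0;
}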
diff --git a/ggml/src/ggml-cuda/concat.cuh b/ggml/src/ggml-cuda/concat.cuh
new file mode 100644
index 00000000..aa506a05
--- /dev/null
+++ b/ggml/src/ggml-cuda/concat.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONCAT_BLOCK_SIZE 256
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu
new file mode 100644
index 00000000..b1e94d6f
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu
@@ -0,0 +1,87 @@
+#include "conv-transpose-1d.cuh"
+
+static __global__ void conv_transpose_1d_kernel(
+ const int s0, const int p0, const int d0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+ const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+ const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+ const float * src0, const float * src1, float * dst) {
+ int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (global_index >= output_size) {
+ return;
+ }
+
+ int out_index = global_index / dst_ne0;
+
+ float accumulator = 0;
+
+ for (int c = 0; c < src0_ne2; c++) {
+ int idx = global_index % dst_ne0;
+
+ int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+ int input_offset = src1_ne0 * c;
+
+ for (int i = 0; i < src1_ne0; i++) {
+ if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+ continue;
+ }
+ int weight_idx = idx - i*s0;
+
+ float kernel_weight = src0[kernel_offset + weight_idx];
+ float input_value = src1[input_offset+i];
+
+ accumulator += kernel_weight * input_value;
+ }
+ }
+ dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_cuda(
+ const int s0, const int p0, const int d0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+ const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+ const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+ const float * src0, const float * src1, float * dst,
+ cudaStream_t stream) {
+
+ const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
+ conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+ s0,p0,d0,output_size,
+ src0_ne0, src0_ne1, src0_ne2, src0_ne3,
+ src1_ne0, src1_ne1, src1_ne2, src1_ne3,
+ dst_ne0, dst_ne1, dst_ne2, dst_ne3,
+ src0,src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+
+ const int s0 = opts[0];
+ const int p0 = 0;//opts[3];
+ const int d0 = 1;//opts[4];
+
+ const int64_t kernel_size = ggml_nelements(src0);
+ const int64_t input_size = ggml_nelements(src1);
+ const int64_t output_size = ggml_nelements(dst);
+
+ conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ src0_d, src1_d, dst_d, stream);
+}
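conv_transpose_1d_kernel above accumulates, for every output index, kernel[idx - i*s0] * input[i] over the input positions i whose window covers idx; padding and dilation are hard-coded to p0 = 0 and d0 = 1. A host-side reference of that accumulation for a single channel, using made-up sizes (the output-length formula here is the standard one for this p0/d0 and is an assumption, since the kernel itself takes the size from the dst tensor):

#include <cstdio>
#include <vector>

int main() {
    const int s0 = 2;                                   // stride
    std::vector<float> kernel = {1.0f, 2.0f, 3.0f};     // src0 row
    std::vector<float> input  = {1.0f, 1.0f};           // src1 row
    const int out_len = (int)((input.size() - 1) * s0 + kernel.size());  // assumed dst_ne0 for p0=0, d0=1

    std::vector<float> out(out_len, 0.0f);
    for (int idx = 0; idx < out_len; ++idx) {
        float acc = 0.0f;
        for (int i = 0; i < (int) input.size(); ++i) {
            if (idx < i*s0 || idx >= i*s0 + (int) kernel.size()) {
                continue;                               // same window test as the kernel
            }
            acc += kernel[idx - i*s0] * input[i];
        }
        out[idx] = acc;
    }
    for (float v : out) printf("%g ", v);               // -> 1 2 4 2 3
    printf("\n");
    return 0;
}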
diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cuh b/ggml/src/ggml-cuda/conv-transpose-1d.cuh
new file mode 100644
index 00000000..6c2cf666
--- /dev/null
+++ b/ggml/src/ggml-cuda/conv-transpose-1d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
new file mode 100644
index 00000000..66e68a52
--- /dev/null
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -0,0 +1,775 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "convert.cuh"
+#include "dequantize.cuh"
+
+#define CUDA_Q8_0_NE_ALIGN 2048
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+ if (i >= k) {
+ return;
+ }
+
+ const int64_t ib = i/qk; // block index
+ const int64_t iqs = (i%qk)/qr; // quant index
+ const int64_t iybs = i - i%qk; // y block start index
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs, v);
+
+ y[iybs + iqs + 0] = v.x;
+ y[iybs + iqs + y_offset] = v.y;
+}
+
+template <bool need_check>
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
+#if __CUDA_ARCH__ >= CC_PASCAL
+ constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
+
+ const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+ const int * x0 = ((int *) vx) + blockIdx.x * nint;
+ half2 * y2 = (half2 *) (y + i0);
+
+ __shared__ int vals[nint];
+
+#pragma unroll
+ for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
+ if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
+ break;
+ }
+
+ const int ix = ix0 + threadIdx.x;
+ vals[ix] = x0[ix];
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
+ if (need_check && i0 + iy + 2*threadIdx.x >= k) {
+ return;
+ }
+
+ const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
+ const half d = *b0;
+ const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
+
+ y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
+ }
+#else
+ GGML_UNUSED(vx);
+ GGML_UNUSED(y);
+ GGML_UNUSED(k);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= CC_PASCAL
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+ const float d = __half2float(x->d);
+ const float dm = -8*d;
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l+ 0] = d * (q[l] & 0xF) + dm;
+ y[l+16] = d * (q[l] >> 4) + dm;
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+ const float2 d = __half22float2(x->dm);
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+ y[l+16] = d.x * (q[l] >> 4) + d.y;
+ }
+}
+
+//================================== k-quants
+
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_q2_K * x = (const block_q2_K *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t n = tid/32;
+ const int64_t l = tid - 32*n;
+ const int64_t is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ float dall = __low2half(x[i].dm);
+ float dmin = __high2half(x[i].dm);
+ y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+ y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+ y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_q3_K * x = (const block_q3_K *) vx;
+
+ const int64_t r = threadIdx.x/4;
+ const int64_t tid = r/2;
+ const int64_t is0 = r%2;
+ const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+ const int64_t n = tid / 4;
+ const int64_t j = tid - 4*n;
+
+ uint8_t m = 1 << (4*n + j);
+ int64_t is = 8*n + 2*j + is0;
+ int shift = 2*j;
+
+ int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+ (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+ float d_all = x[i].d;
+ float dl = d_all * (us - 32);
+
+ dst_t * y = yy + i*QK_K + 128*n + 32*j;
+ const uint8_t * q = x[i].qs + 32*n;
+ const uint8_t * hm = x[i].hmask;
+
+ for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+ if (j < 4) {
+ d = q[j] & 63; m = q[j + 4] & 63;
+ } else {
+ d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q4_K * x = (const block_q4_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t is = 2*il;
+ const int64_t n = 4;
+
+ dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+ for (int l = 0; l < n; ++l) {
+ y[l + 0] = d1 * (q[l] & 0xF) - m1;
+ y[l +32] = d2 * (q[l] >> 4) - m2;
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q5_K * x = (const block_q5_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+ // assume 64 threads - this is very slightly better than the one below
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/16; // il is in 0...3
+ const int64_t ir = tid%16; // ir is in 0...15
+ const int64_t is = 2*il; // is is in 0...6
+
+ dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+ const uint8_t * qh = x[i].qh + 2*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+
+ uint8_t hm = 1 << (2*il);
+ y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+ y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+ hm <<= 1;
+ y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+ y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q6_K * x = (const block_q6_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+ // assume 64 threads - this is very slightly better than the one below
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/32; // ip is 0 or 1
+ const int64_t il = tid - 32*ip; // 0...32
+ const int64_t is = 8*ip + il/16;
+
+ dst_t * y = yy + i*QK_K + 128*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t * ql = x[i].ql + 64*ip + il;
+ const uint8_t qh = x[i].qh[32*ip + il];
+ const int8_t * sc = x[i].scales + is;
+
+ y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+ y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+ y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * q3 = x[i].qs + 8*ib;
+ const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq3_s * x = (const block_iq3_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * qs = x[i].qs + 8*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+ const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+ const uint8_t signs = x[i].signs[4*ib + il];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq1_s * x = (const block_iq1_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq1_m * x = (const block_iq1_m *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
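+    // the fp16 super-block scale is scattered over the top 4 bits of the four 16-bit scale words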
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
+
+ const int64_t ii = blockIdx.x;
+ const block_iq1_bn * x = (const block_iq1_bn *) vx;
+
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+//#define COMPUTE_VS(v) 3*v >> 8
+#define COMPUTE_VS(v) (v + (v >> 1)) >> 7
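+// (v + v/2) >> 7 approximates 3*v/256 (see the commented-out variant above); combined with the powers of
+// 3 in k_mult this recovers one ternary digit (0, 1 or 2) per call, and the "- 1" below maps it to {-1, 0, 1}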
+
+ const int tid = threadIdx.x;
+ const int il = tid/4; // 0...7
+ const int ib = tid%4; // 0...3
+ dst_t * y = yy + ii*QK_K + 64*ib + 8*il;
+ int64_t i = QK_K/QK_IQ1BN * ii + ib;
+ if (i >= nb64) return;
+ const int i16 = il/2;
+ uint8_t q = x[i].ql[3*i16+2*(il%2)];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[2*(il%2)+j] = vs - 1;
+ }
+ q = x[i].ql[3*i16+1];
+ for (int j = 0; j < 2; ++j) {
+ uint8_t v = k_mult[3*(il%2)+j]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[5*(1-(il%2))+j] = vs-1;
+ }
+ uint8_t v = (il%2) ? k_mult[i16]*x[i].extra : k_mult[2]*q;
+ int8_t vs = COMPUTE_VS(v);
+ y[7] = vs - 1;
+
+#undef COMPUTE_VS
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_bn(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb64) {
+
+ const int64_t ii = blockIdx.x;
+ const block_iq2_bn * x = (const block_iq2_bn *) vx;
+
+ const int64_t tid = threadIdx.x;
+ int64_t ib64 = tid%4; // 0...3
+ int64_t il = tid/4; // 0...7
+ dst_t * y = yy + 256*ii + 64*ib64 + 2*il;
+ int64_t i = 256/QK_IQ1BN * ii + ib64;
+ if (i >= nb64) return;
+ const float m = -1;
+ auto qs = x[i].qs + 2*il;
+ for (int j = 0; j < 2; ++j) {
+ y[j+ 0] = ((qs[j] >> 0) & 3) + m;
+ y[j+16] = ((qs[j] >> 2) & 3) + m;
+ y[j+32] = ((qs[j] >> 4) & 3) + m;
+ y[j+48] = ((qs[j] >> 6) & 3) + m;
+ }
+}
+
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = (float)x[ib].d;
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const int64_t i = blockIdx.x;
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+ const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+ dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
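+    // the bounds-checking variant is only needed when k is not a multiple of CUDA_Q8_0_NE_ALIGN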
+ if (k % CUDA_Q8_0_NE_ALIGN == 0) {
+ const bool need_check = false;
+ dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+ } else {
+ const bool need_check = true;
+ dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+ }
+}
+
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb64 = k / QK_IQ1BN;
+ const int nb = (k + 255) / 256;
+ dequantize_block_iq1_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_bn_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb64 = k / QK_IQ1BN;
+ const int nb = (k + 255) / 256;
+ dequantize_block_iq2_bn<<<nb, 32, 0, stream>>>(vx, y, nb64);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const src_t * x = (src_t *) vx;
+
+ y[i] = x[i];
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+ convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
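+            // Pascal and newer use a dedicated q8_0 -> f16 kernel; older GPUs fall back to the generic path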
+ if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
+ return dequantize_block_q8_0_f16_cuda;
+ }
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_cuda;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_cuda;
+ case GGML_TYPE_IQ1_BN:
+ return dequantize_row_iq1_bn_cuda;
+ case GGML_TYPE_IQ2_BN:
+ return dequantize_row_iq2_bn_cuda;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_F32:
+ return convert_unary_cuda<float>;
+ default:
+ return nullptr;
+ }
+}
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_cuda;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_cuda;
+ case GGML_TYPE_IQ1_BN:
+ return dequantize_row_iq1_bn_cuda;
+ case GGML_TYPE_IQ2_BN:
+ return dequantize_row_iq2_bn_cuda;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_F16:
+ return convert_unary_cuda<half>;
+ default:
+ return nullptr;
+ }
+}
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
new file mode 100644
index 00000000..5394be9f
--- /dev/null
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -0,0 +1,13 @@
+#include "common.cuh"
+
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
+template<typename T>
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
new file mode 100644
index 00000000..3db57034
--- /dev/null
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -0,0 +1,489 @@
+#include "cpy.cuh"
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ float * dsti = (float *) cdsti;
+
+ *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ half * dsti = (half *) cdsti;
+
+ *dsti = __float2half(*xi);
+}
+
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+ const half * xi = (const half *) cxi;
+ half * dsti = (half *) cdsti;
+
+ *dsti = *xi;
+}
+
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+ const half * xi = (const half *) cxi;
+ float * dsti = (float *) cdsti;
+
+ *dsti = *xi;
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+ // then combine those indices with the corresponding byte offsets to get the total offsets
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = xi[j];
+ amax = fmaxf(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = xi[j]*id;
+
+ dsti->qs[j] = roundf(x0);
+ }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK4_0; ++j) {
+ const float v = xi[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+ dsti->qs[j] = xi0;
+ dsti->qs[j] |= xi1 << 4;
+ }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+ float vmin = FLT_MAX;
+ float vmax = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; ++j) {
+ const float v = xi[j];
+
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->dm.x = d;
+ dsti->dm.y = vmin;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (xi[0 + j] - vmin)*id;
+ const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+ dsti->qs[j] = xi0;
+ dsti->qs[j] |= xi1 << 4;
+ }
+}
+
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK5_0; ++j) {
+ const float v = xi[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+ dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+ float min = xi[0];
+ float max = xi[0];
+
+ for (int j = 1; j < QK5_1; ++j) {
+ const float v = xi[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->dm.x = d;
+ dsti->dm.y = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (xi[0 + j] - min)*id;
+ const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+
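+// binary search for the entry of the sorted table val[0..n-1] that is closest to x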
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK4_NL; ++j) {
+ const float v = xi[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ float d = vmax / kvalues_iq4nl[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK4_NL/2 + j]*id;
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+ dsti->qs[j] = xi0 | (xi1 << 4);
+ const float v0 = kvalues_iq4nl[xi0];
+ const float v1 = kvalues_iq4nl[xi1];
+ const float w0 = xi[0 + j]*xi[0 + j];
+ const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+ sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+ }
+
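+    // refine the block scale: sumqx/sumq2 is the weighted least-squares optimal d for the chosen quants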
+ dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
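+    // the destination is quantized: every qk consecutive source values map to one block, hence i10/qk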
+ const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+static void ggml_cpy_f16_f32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q8_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK8_0 == 0);
+ const int num_blocks = ne / QK8_0;
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_0 == 0);
+ const int num_blocks = ne / QK4_0;
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_1 == 0);
+ const int num_blocks = ne / QK4_1;
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_0_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK5_0 == 0);
+ const int num_blocks = ne / QK5_0;
+ cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK5_1 == 0);
+ const int num_blocks = ne / QK5_1;
+ cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_NL == 0);
+ const int num_blocks = ne / QK4_NL;
+ cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f16_f16_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+ const int64_t ne = ggml_nelements(src0);
+ GGML_ASSERT(ne == ggml_nelements(src1));
+
+ GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+ GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ //GGML_ASSERT(src0->ne[3] == 1);
+
+ const int64_t nb00 = src0->nb[0];
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+ const int64_t nb03 = src0->nb[3];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+
+ //GGML_ASSERT(src1->ne[3] == 1);
+
+ const int64_t nb10 = src1->nb[0];
+ const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2];
+ const int64_t nb13 = src1->nb[3];
+
+ cudaStream_t main_stream = ctx.stream();
+
+ char * src0_ddc = (char *) src0->data;
+ char * src1_ddc = (char *) src1->data;
+
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+ ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+ ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+ ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ ggml_cuda_cpy(ctx, src0, dst);
+}
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+ return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+ } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
new file mode 100644
index 00000000..79616742
--- /dev/null
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -0,0 +1,9 @@
+#include "common.cuh"
+
+#define CUDA_CPY_BLOCK_SIZE 32
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh
new file mode 100644
index 00000000..bd3c2d9d
--- /dev/null
+++ b/ggml/src/ggml-cuda/dequantize.cuh
@@ -0,0 +1,103 @@
+#include "common.cuh"
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+ v = __hsub2(v, {8.0f, 8.0f});
+ v = __hmul2(v, {d, d});
+#else
+ v.x = (v.x - 8.0f) * d;
+ v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const dfloat d = __low2half(x[ib].dm);
+ const dfloat m = __high2half(x[ib].dm);
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+ v = __hmul2(v, {d, d});
+ v = __hadd2(v, {m, m});
+#else
+ v.x = (v.x * d) + m;
+ v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
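+    // the 5th bit of each value is stored separately in qh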
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+ v = __hsub2(v, {16.0f, 16.0f});
+ v = __hmul2(v, {d, d});
+#else
+ v.x = (v.x - 16.0f) * d;
+ v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const dfloat d = __low2half(x[ib].dm);
+ const dfloat m = __high2half(x[ib].dm);
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+ v = __hmul2(v, {d, d});
+ v = __hadd2(v, {m, m});
+#else
+ v.x = (v.x * d) + m;
+ v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ v.x = x[ib].qs[iqs + 0];
+ v.y = x[ib].qs[iqs + 1];
+
+#ifdef GGML_CUDA_F16
+ v = __hmul2(v, {d, d});
+#else
+ v.x *= d;
+ v.y *= d;
+#endif // GGML_CUDA_F16
+}
diff --git a/ggml/src/ggml-cuda/diagmask.cu b/ggml/src/ggml-cuda/diagmask.cu
new file mode 100644
index 00000000..4b713ba2
--- /dev/null
+++ b/ggml/src/ggml-cuda/diagmask.cu
@@ -0,0 +1,40 @@
+#include "diagmask.cuh"
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+ const int col = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int i = row*ncols + col;
+ //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+ //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+ const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+ const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+ const dim3 block_nums(nrows_x, block_num_x, 1);
+ diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int nrows0 = ggml_nrows(src0);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+
+ diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
+}
diff --git a/ggml/src/ggml-cuda/diagmask.cuh b/ggml/src/ggml-cuda/diagmask.cuh
new file mode 100644
index 00000000..6cdbef17
--- /dev/null
+++ b/ggml/src/ggml-cuda/diagmask.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
new file mode 100644
index 00000000..174489e0
--- /dev/null
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -0,0 +1,674 @@
+#include "dmmv.cuh"
+#include "dequantize.cuh"
+#include "convert.cuh"
+
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int step = 16/K_QUANTS_PER_ITERATION;
+
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
+
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int s_offset = 8*im;
+ const int y_offset = 128*im + l0;
+
+ uint32_t aux[4];
+ const uint8_t * d = (const uint8_t *)aux;
+ const uint8_t * m = (const uint8_t *)(aux + 2);
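+    // aux receives the unpacked 4-bit sub-block scales (d) and mins (m) from x[i].scales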
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * q = x[i].qs + q_offset;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+ aux[0] = a[0] & 0x0f0f0f0f;
+ aux[1] = a[1] & 0x0f0f0f0f;
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+ }
+ tmp += dall * sum1 - dmin * sum2;
+
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+ const int step = 16/K_QUANTS_PER_ITERATION;
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im; // 0...15 or 0...7
+
+ const uint8_t m = 1 << (4*im);
+
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int y_offset = 128*im + l0;
+
+ uint16_t utmp[4];
+ const int8_t * s = (const int8_t *)utmp;
+
+ const uint16_t s_shift = 4*im;
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * q = x[i].qs + q_offset;
+ const uint8_t * h = x[i].hmask + l0;
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+ const float d = x[i].d;
+
+ float sum = 0;
+ for (int l = 0; l < n; ++l) {
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+ }
+ tmp += d * sum;
+
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+ const int il = tid/step; // 0...3
+ const int ir = tid - step*il; // 0...7 or 0...3
+ const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+ uint32_t q32[4];
+ const uint8_t * q4 = (const uint8_t *)q32;
+#else
+ uint16_t q16[4];
+ const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y1 = yy + i*QK_K + y_offset;
+ const float * y2 = y1 + 128;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+ const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+ const uint32_t * q2 = q1 + 16;
+
+ q32[0] = q1[0] & 0x0f0f0f0f;
+ q32[1] = q1[0] & 0xf0f0f0f0;
+ q32[2] = q2[0] & 0x0f0f0f0f;
+ q32[3] = q2[0] & 0xf0f0f0f0;
+
+ float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < 4; ++l) {
+ s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+ s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+ const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+ const uint16_t * q2 = q1 + 32;
+
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[0] & 0xf0f0;
+ q16[2] = q2[0] & 0x0f0f;
+ q16[3] = q2[0] & 0xf0f0;
+
+ float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < 2; ++l) {
+ s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+ s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (tid == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
+
+ const int row = blockIdx.x;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid = threadIdx.x/2; // 0...15
+ const int ix = threadIdx.x%2;
+
+ const int il = tid/4; // 0...3
+ const int ir = tid - 4*il;// 0...3
+ const int n = 2;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ const uint8_t hm1 = 1 << (2*im);
+ const uint8_t hm2 = hm1 << 4;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+ uint16_t q16[8];
+ const uint8_t * q4 = (const uint8_t *)q16;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+ const uint8_t * ql1 = x[i].qs + q_offset;
+ const uint8_t * qh = x[i].qh + l0;
+ const float * y1 = yy + i*QK_K + y_offset;
+ const float * y2 = y1 + 128;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+ float4 sum = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ const uint16_t * q1 = (const uint16_t *)ql1;
+ const uint16_t * q2 = q1 + 32;
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[8] & 0x0f0f;
+ q16[2] = (q1[0] >> 4) & 0x0f0f;
+ q16[3] = (q1[8] >> 4) & 0x0f0f;
+ q16[4] = q2[0] & 0x0f0f;
+ q16[5] = q2[8] & 0x0f0f;
+ q16[6] = (q2[0] >> 4) & 0x0f0f;
+ q16[7] = (q2[8] >> 4) & 0x0f0f;
+ for (int l = 0; l < n; ++l) {
+ sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+ sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+ sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+ sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+ }
+ tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
+ const int is = 0;
+#else
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
+ const int is = in / 4;
+#endif
+ const int ql_offset = 64*im + l0;
+ const int qh_offset = 32*im + l0;
+ const int s_offset = 8*im + is;
+ const int y_offset = 128*im + l0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * ql = x[i].ql + ql_offset;
+ const uint8_t * qh = x[i].qh + qh_offset;
+ const int8_t * s = x[i].scales + s_offset;
+
+ const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+ tmp += sum;
+#else
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+ }
+ tmp += sum;
+#endif
+
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (tid == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+ const half * x = (const half *) vx;
+
+ // automatic half -> float type cast if dfloat == float
+ v.x = x[ib + iqs + 0];
+ v.y = x[ib + iqs + 1];
+}
+
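+// compile-time selection of the per-type dequantize function used by the templated dmmv kernel below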
+static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
+ return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
+ type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
+ type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
+ type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
+ type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
+ type == GGML_TYPE_F16 ? convert_f16 :
+ nullptr;
+}
+
+template <ggml_type type>
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
+ constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
+ constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
+ constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
+
+ const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int tid = threadIdx.x;
+
+ const int iter_stride = 2*GGML_CUDA_DMMV_X;
+ const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_CUDA_F16
+ half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+ float tmp = 0.0f;
+#endif // GGML_CUDA_F16
+
+ for (int i = 0; i < ncols; i += iter_stride) {
+ const int col = i + vals_per_iter*tid;
+ const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
+ const int iqs = (col%qk)/qr; // x quant index
+ const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+ for (int j = 0; j < vals_per_iter; j += 2) {
+ // process 2 vals per j iter
+
+ // dequantize
+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+ // matrix multiplication
+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_F16
+ tmp += __hmul2(v, {
+ y[iybs + iqs + j/qr + 0],
+ y[iybs + iqs + j/qr + y_offset]
+ });
+#else
+ tmp += v.x * y[iybs + iqs + j/qr + 0];
+ tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_F16
+ }
+ }
+
+ // sum up partial sums and write back result
+ tmp = warp_reduce_sum(tmp);
+
+ if (tid == 0) {
+#ifdef GGML_CUDA_F16
+ dst[row] = tmp.x + tmp.y;
+#else
+ dst[row] = tmp;
+#endif // GGML_CUDA_F16
+ }
+}
+
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    // the number of rows may exceed the maximum grid size in the y or z dimensions, so the x dimension is used instead
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const dim3 block_dims(32, 1, 1);
+ dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
+}
+
+static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(32, ny, 1);
+ dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(block_num_y, 1, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ dequantize_mul_mat_vec<GGML_TYPE_F16>
+ <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
+void ggml_cuda_op_dequantize_mul_mat_vec(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+ GGML_UNUSED(ctx);
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_F16
+ ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
+ half * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 =
+ src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = src1_dfloat_a.alloc(ne00);
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
+ }
+#else
+ const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_CUDA_F16
+
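+    // note: the k-quant kernels (q2_K..q6_K) read the original f32 src1 (src1_ddf_i),
+    //       the legacy quants and f16 use src1_dfloat (possibly converted to f16 above)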
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_F16:
+ convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ GGML_UNUSED(src1);
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddq_i);
+ GGML_UNUSED(src1_ncols);
+ GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/ggml/src/ggml-cuda/dmmv.cuh b/ggml/src/ggml-cuda/dmmv.cuh
new file mode 100644
index 00000000..4c5ebd47
--- /dev/null
+++ b/ggml/src/ggml-cuda/dmmv.cuh
@@ -0,0 +1,18 @@
+#include "common.cuh"
+
+// dmmv = dequantize_mul_mat_vec
+
+// TODO: remove this?
+#ifndef GGML_CUDA_DMMV_X
+#define GGML_CUDA_DMMV_X 32
+#endif
+
+#ifndef GGML_CUDA_MMV_Y
+#define GGML_CUDA_MMV_Y 1
+#endif
+
+void ggml_cuda_op_dequantize_mul_mat_vec(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
new file mode 100644
index 00000000..f24312dd
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -0,0 +1,701 @@
+#pragma once
+
+#include "common.cuh"
+#include "convert.cuh"
+#include "vecdotq.cuh"
+
+#include <cstdint>
+
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exponentials of values smaller than this are flushed to zero to avoid NaNs.
+
+typedef void (* fattn_kernel_t)(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3);
+
+typedef half (*vec_dot_KQ_f16_t)(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+typedef float (*vec_dot_KQ_f32_t)(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
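+    // Per-thread partial dot product of one q4_0 K row with the q8_1-quantized Q values; the caller reduces the result across the warp.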
+ const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_0;
+ const int shift = k_KQ & (QI8_1/2);
+
+ const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+ const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+ sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+ sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+ }
+ }
+
+ return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+ const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
+ sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+ const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+ const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+ sum += (T) (sumid4d8 + m4s8scaled);
+ }
+ }
+
+ return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_0;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+ const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+ sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+ sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+ }
+ }
+
+ return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_1;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+ const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+ const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+
+ const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
+ sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
+ } else
+#endif // FP16_AVAILABLE
+ {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+ const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+ const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+ sum += (T) (sumid5d8 + m5s8scaled);
+ }
+ }
+
+ return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ T sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const int ib = k_KQ / QI8_0;
+ const int iqs = k_KQ % QI8_0;
+
+ const int v = get_int_b2(K_q8_0[ib].qs, iqs);
+
+ T Q_d;
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_ds = (const half2 *) Q_ds_v;
+ Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+ } else {
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+ Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+ }
+
+ sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+ }
+
+ return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+ const half2 * K_h2 = (const half2 *) K_c;
+ GGML_UNUSED(Q_q8);
+ GGML_UNUSED(Q_ds_v);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ const half2 * Q_h2 = (const half2 *) Q_v;
+
+ half2 sum2 = make_half2(0.0f, 0.0f);
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const half2 K_ik = K_h2[k_KQ];
+ sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+ }
+
+ return __low2half(sum2) + __high2half(sum2);
+ }
+#endif // FP16_AVAILABLE
+
+ const float2 * Q_f2 = (const float2 *) Q_v;
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ const half2 K_ik = K_h2[k_KQ];
+ sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
+ sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+ }
+
+ return sum;
+}
+
+template <typename Tds>
+static __device__ __forceinline__ void quantize_q8_1_to_shared(
+ const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
+
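+    // Each thread scales and quantizes sizeof(int) = 4 consecutive floats to q8_1: the packed 8-bit values go to yq32, one (d, sum) pair per block goes to yds.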
+ float vals[sizeof(int)] = {0.0f};
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ vals[l] = scale * x[4*threadIdx.x + l];
+ }
+
+ float amax = fabsf(vals[0]);
+ float sum = vals[0];
+#pragma unroll
+ for (int l = 1; l < sizeof(int); ++l) {
+ amax = fmaxf(amax, fabsf(vals[l]));
+ sum += vals[l];
+ }
+#pragma unroll
+ for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
+ }
+
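+    // amax and sum are now reduced across each group of QI8_1 threads, i.e. across one q8_1 block of values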
+ const float d = amax / 127;
+ int q32 = 0;
+ int8_t * q8 = (int8_t *) &q32;
+
+ if (d != 0.0f) {
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ q8[l] = roundf(vals[l] / d);
+ }
+ }
+
+ yq32[threadIdx.x] = q32;
+ if (threadIdx.x % QI8_1 == 0) {
+ if (std::is_same<Tds, half2>::value) {
+ ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
+ } else {
+ ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
+ }
+ }
+}
+
+typedef half (*dequantize_1_f16_t)(const void *, const int64_t);
+typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const int64_t ib = i / QK4_0;
+ const int iqs = i % (QK4_0/2);
+ const int shift = (i % QK4_0) / (QK4_0/2);
+
+ const T d = x[ib].d;
+ const int q0 = x[ib].qs[iqs];
+ const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return ((half) d)*((half) q);
+ }
+#endif // FP16_AVAILABLE
+
+ return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const int64_t ib = i / QK4_1;
+ const int iqs = i % (QK4_1/2);
+ const int shift = (i % QK4_1) / (QK4_1/2);
+
+ const half2 dm = x[ib].dm;
+ const int q0 = x[ib].qs[iqs];
+ const int q = ((q0 >> (4*shift)) & 0x0F);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return __low2half(dm)*((half) q) + __high2half(dm);
+ }
+#endif // FP16_AVAILABLE
+
+ return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const int64_t ib = i / QK5_0;
+ const int idq = i % QK5_0;
+ const int iqs = i % (QK5_0/2);
+ const int shift = (i % QK5_0) / (QK5_0/2);
+
+ const T d = x[ib].d;
+ const int ql0 = x[ib].qs[iqs];
+ const int qh0 = get_int_b2(x[ib].qh, 0);
+ const int ql = ((ql0 >> (4*shift)) & 0x0F);
+ const int qh = ((qh0 >> idq) << 4) & 0x10;
+ const int q = (ql | qh) - 16;
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return ((half) d)*((half) q);
+ }
+#endif // FP16_AVAILABLE
+
+ return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const int64_t ib = i / QK5_1;
+ const int idq = i % QK5_1;
+ const int iqs = i % (QK5_1/2);
+ const int shift = (i % QK5_1) / (QK5_1/2);
+
+ const half2 dm = x[ib].dm;
+ const int ql0 = x[ib].qs[iqs];
+ const int qh0 = get_int_b4(x[ib].qh, 0);
+ const int ql = ((ql0 >> (4*shift)) & 0x0F);
+ const int qh = ((qh0 >> idq) << 4) & 0x10;
+ const int q = (ql | qh);
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return __low2half(dm)*((half) q) + __high2half(dm);
+ }
+#endif // FP16_AVAILABLE
+
+ return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const int64_t ib = i / QK8_0;
+ const int iqs = i % QK8_0;
+
+ const T d = x[ib].d;
+ const int q = x[ib].qs[iqs];
+
+#ifdef FP16_AVAILABLE
+ if (std::is_same<T, half>::value) {
+ return ((half) d)*((half) q);
+ }
+#endif // FP16_AVAILABLE
+
+ return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
+ const half * x = (const half *) vx;
+
+ return x[i];
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+ nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+ nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
+ nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+ return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+ type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+ type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+ type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+ type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+ type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
+ nullptr;
+}
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+ const float * __restrict__ VKQ_parts,
+ const float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst) {
+ VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+ VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
+ dst += D * gridDim.y*blockIdx.x;
+
+ const int tid = threadIdx.x;
+ __builtin_assume(tid < D);
+
+ __shared__ float2 meta[parallel_blocks];
+ if (tid < 2*parallel_blocks) {
+ ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+ }
+
+ __syncthreads();
+
+ float kqmax = meta[0].x;
+#pragma unroll
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = max(kqmax, meta[l].x);
+ }
+
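+    // Rescale each partial result by exp(kqmax_part - kqmax_global) so that all parts share a common softmax maximum.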
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+#pragma unroll
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float diff = meta[l].x - kqmax;
+ const float KQ_max_scale = expf(diff);
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y;
+ }
+
+ dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
+
+static void on_no_fattn_vec_case(const int D) {
+ if (D == 64) {
+ fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+ fprintf(stderr, "By default only f16 KV cache is supported.\n");
+ fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+ GGML_ASSERT(false);
+ } else if (D == 128) {
+ fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
+ fprintf(stderr, "Supported combinations:\n");
+ fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
+ fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
+ fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
+ fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+ GGML_ASSERT(false);
+ } else {
+ fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+ fprintf(stderr, "Only f16 is supported.\n");
+ GGML_ASSERT(false);
+ }
+}
+
+template <int D, int parallel_blocks>
+void launch_fattn(
+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+ const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ const ggml_tensor * mask = dst->src[3];
+
+ ggml_tensor * KQV = dst;
+
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+ GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+ "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+ GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t main_stream = ctx.stream();
+
+ ggml_cuda_pool_alloc<half> K_f16(pool);
+ ggml_cuda_pool_alloc<half> V_f16(pool);
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+ char * K_data = (char *) K->data;
+ size_t nb11 = K->nb[1];
+ size_t nb12 = K->nb[2];
+ size_t nb13 = K->nb[3];
+
+ char * V_data = (char *) V->data;
+ size_t nb21 = V->nb[1];
+ size_t nb22 = V->nb[2];
+ size_t nb23 = V->nb[3];
+
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
+ K_f16.alloc(ggml_nelements(K));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+ K_data = (char *) K_f16.ptr;
+
+ const size_t bs = ggml_blck_size(K->type);
+ const size_t ts = ggml_type_size(K->type);
+
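+        // rescale the byte strides from the quantized K layout to the dequantized f16 copy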
+ nb11 = nb11*bs*sizeof(half)/ts;
+ nb12 = nb12*bs*sizeof(half)/ts;
+ nb13 = nb13*bs*sizeof(half)/ts;
+ }
+
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
+ V_f16.alloc(ggml_nelements(V));
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+ V_data = (char *) V_f16.ptr;
+
+ const size_t bs = ggml_blck_size(V->type);
+ const size_t ts = ggml_type_size(V->type);
+
+ nb21 = nb21*bs*sizeof(half)/ts;
+ nb22 = nb22*bs*sizeof(half)/ts;
+ nb23 = nb23*bs*sizeof(half)/ts;
+ }
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+
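+    // launch config: one block per (tile of Q columns, head, batch) triple; with parallel_blocks > 1 that many blocks cooperate on each tile and are combined below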
+ const dim3 block_dim(WARP_SIZE, nwarps, 1);
+ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+ const int shmem = 0;
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
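+    // ALiBi slope parameters; with max_bias == 0.0f both evaluate to 1.0f and the slopes have no effect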
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+ (const char *) Q->data,
+ K_data,
+ V_data,
+ mask ? ((const char *) mask->data) : nullptr,
+ (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+ scale, max_bias, m0, m1, n_head_log2,
+ Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+ mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+ Q->nb[1], Q->nb[2], Q->nb[3],
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ if ((parallel_blocks) == 1) {
+ return;
+ }
+
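+    // parallel_blocks > 1: combine the partial results written by the cooperating blocks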
+ const dim3 block_dim_combine(D, 1, 1);
+ const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+ const int shmem_combine = 0;
+
+ flash_attn_combine_results<D, parallel_blocks>
+ <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+ CUDA_CHECK(cudaGetLastError());
+}
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
new file mode 100644
index 00000000..c6c35134
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -0,0 +1,319 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F16 64
+
+template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#ifdef FP16_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) mask + ne11*ic0;
+
+ const int stride_KV2 = nb11 / sizeof(half2);
+
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+ const half slopeh = __float2half(slopef);
+
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+ __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
+ half2 * KQ2 = (half2 *) KQ;
+
+ __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
+
+ half kqmax[ncols/nwarps];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ kqmax[j0/nwarps] = -HALF_MAX_HALF;
+ }
+ half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
+
+ half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Convert Q to half2 and store it in shared memory:
+ __shared__ half2 Q_h2[ncols][D/2];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
+ Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+ }
+ }
+
+ __syncthreads();
+
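+    // Loop over the KV cache in tiles of FATTN_KQ_STRIDE_TILE_F16 positions; with parallel_blocks > 1 the tiles are split between cooperating blocks.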
+ const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+
+ half kqmax_new[ncols/nwarps];
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ kqmax_new[j] = kqmax[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
+ const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ const int k_KQ = k_KQ_0 + threadIdx.x;
+
+ KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+ }
+ }
+
+ __syncthreads();
+
+ half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
+
+#pragma unroll
+ for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
+ half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
+ half2 Q_k[ncols/nwarps];
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+ const int i_KQ = i_KQ_0 + threadIdx.x;
+
+ K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+ }
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ const int j_KQ = j_KQ_0 + threadIdx.y;
+
+ Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+ const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ const int j_KQ = j_KQ_0 + threadIdx.y;
+
+ half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+ kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
+
+ KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
+ }
+ }
+
+ __syncthreads();
+
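+        // Online softmax update: rescale the running sums and VKQ accumulators by exp(old_max - new_max), then add the new exponentials.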
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+ const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
+ kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+#pragma unroll
+ for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
+ const half2 val = h2exp(diff);
+ kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
+ KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
+ const int k = k0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
+ half2 V_k[(D/2)/WARP_SIZE][2];
+ half2 KQ_k[ncols/nwarps];
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
+ V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
+ VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
+ }
+ }
+ }
+
+ __syncthreads();
+ }
+
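+    // Write back the VKQ accumulators; with parallel_blocks == 1 they are normalized here, otherwise normalization happens in the combine kernel.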
+#pragma unroll
+ for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+ const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+
+ half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
+ kqsum_j = warp_reduce_sum(kqsum_j);
+
+#pragma unroll
+ for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+ const int i0 = i00 + 2*threadIdx.x;
+
+ half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+ if (parallel_blocks == 1) {
+ dst_val /= __half2half2(kqsum_j);
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+ }
+
+ if (parallel_blocks != 1 && threadIdx.x == 0) {
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+ }
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+ switch (Q->ne[0]) {
+ case 64: {
+ constexpr int D = 64;
+ constexpr int nwarps = 8;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ } break;
+ case 128: {
+ constexpr int D = 128;
+ constexpr int nwarps = 8;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ } break;
+ default: {
+ GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+ } break;
+ }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ const int32_t precision = KQV->op_params[2];
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
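+    // Q->ne[1] is the number of query columns: batches of at most 32 columns use parallel_blocks = 4, larger ones a single block per column tile.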
+ if (Q->ne[1] <= 16) {
+ constexpr int cols_per_block = 16;
+ constexpr int parallel_blocks = 4;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 32) {
+ constexpr int cols_per_block = 32;
+ constexpr int parallel_blocks = 4;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ return;
+ }
+
+ constexpr int cols_per_block = 32;
+ constexpr int parallel_blocks = 1;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+}
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh
new file mode 100644
index 00000000..ffc58784
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
new file mode 100644
index 00000000..15e22f49
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -0,0 +1,312 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f32.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F32 32
+
+template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f32(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
+ const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) mask + ne11*ic0;
+
+ const int stride_KV2 = nb11 / sizeof(half2);
+
+ const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+ __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
+
+ __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+ float2 * KV_tmp2 = (float2 *) KV_tmp;
+
+ float kqmax[ncols/nwarps];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+ }
+ float kqsum[ncols/nwarps] = {0.0f};
+
+ float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Scale Q and store it as float in shared memory:
+ __shared__ float Q_f[ncols][D];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+ float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
+ Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
+ Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
+ }
+ }
+
+ __syncthreads();
+
+ const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+
+ float kqmax_new[ncols/nwarps];
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ kqmax_new[j] = kqmax[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
+ const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+ const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+ KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
+ KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
+ }
+ }
+
+ __syncthreads();
+
+ float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
+
+#pragma unroll
+ for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
+ float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
+ float Q_k[ncols/nwarps];
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+ const int i_KQ = i_KQ_0 + threadIdx.x;
+
+ K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+ }
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ const int j_KQ = j_KQ_0 + threadIdx.y;
+
+ Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+ const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+ for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+ const int j_KQ = j_KQ_0 + threadIdx.y;
+
+ sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+ kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+
+ KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+ const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
+ kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+ float kqsum_add = 0.0f;
+#pragma unroll
+ for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
+ const float val = expf(diff);
+ kqsum_add += val;
+ KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
+ }
+ kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+ VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
+ const int k = k0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+ KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
+ float2 V_k[(D/2)/WARP_SIZE];
+ float KQ_k[ncols/nwarps];
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
+ VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
+ }
+ }
+ }
+
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+ const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+
+ float kqsum_j = kqsum[j_VKQ_0/nwarps];
+ kqsum_j = warp_reduce_sum(kqsum_j);
+
+#pragma unroll
+ for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+ const int i0 = i00 + 2*threadIdx.x;
+
+ float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+ if (parallel_blocks == 1) {
+ dst_val.x /= kqsum_j;
+ dst_val.y /= kqsum_j;
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+ }
+
+ if (parallel_blocks != 1 && threadIdx.x == 0) {
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+ }
+ }
+}
+
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+ switch (Q->ne[0]) {
+ case 64: {
+ constexpr int D = 64;
+ constexpr int nwarps = 8;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ } break;
+ case 128: {
+ constexpr int D = 128;
+ constexpr int nwarps = 8;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ } break;
+ default: {
+ GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+ } break;
+ }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+
+ if (Q->ne[1] <= 16) {
+ constexpr int cols_per_block = 16;
+ constexpr int parallel_blocks = 4;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 32) {
+ constexpr int cols_per_block = 32;
+ constexpr int parallel_blocks = 4;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ return;
+ }
+
+ constexpr int cols_per_block = 32;
+ constexpr int parallel_blocks = 1;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+}
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh
new file mode 100644
index 00000000..b1c546c8
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
new file mode 100644
index 00000000..02a4ad07
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -0,0 +1,397 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#ifdef FP16_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
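+    // For quantized K, Q is converted to q8_1 so that the KQ dot products can use integer dp4a intrinsics.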
+ constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+ constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
+
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ Q += nb02* blockIdx.y + nb01*ic0;
+ K += nb12*(blockIdx.y / gqa_ratio);
+ V += nb22*(blockIdx.y / gqa_ratio);
+
+ const half * maskh = (const half *) mask + ne11*ic0;
+
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+ const half slopeh = __float2half(slopef);
+
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+ constexpr int nwarps = D / WARP_SIZE;
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ __builtin_assume(tid < D);
+
+ __shared__ half KQ[ncols*D];
+ half2 * KQ2 = (half2 *) KQ;
+
+ half kqmax[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqmax[j] = -HALF_MAX_HALF;
+ }
+ half kqsum[ncols] = {0.0f};
+
+ __shared__ half kqmax_shared[ncols][WARP_SIZE];
+ __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.y == 0) {
+ kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+ kqsum_shared[j][threadIdx.x] = 0.0f;
+ }
+ }
+ __syncthreads();
+
+ // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
+ half2 Q_h2[ncols][D/(2*WARP_SIZE)];
+ int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+ half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+ if (Q_q8_1) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+
+ // Reuse KQ as temporary storage for converting Q to q8_1:
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+ // Set memory to zero if out of bounds:
+ if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ tmp_q_i32[i] = 0;
+ }
+ if (threadIdx.x < D/QK8_1) {
+ tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
+ }
+ continue;
+ }
+
+ const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+ Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
+ }
+ }
+
+ __syncthreads();
+ } else {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+ Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+ }
+ }
+ }
+
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ[j*D + tid] = -HALF_MAX_HALF;
+ }
+
+ half2 VKQ[ncols] = {{0.0f, 0.0f}};
+
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+
+ // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+ // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+ // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+ half kqmax_new = kqmax[0];
+ half kqmax_new_arr[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqmax_new_arr[j] = kqmax[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+ const int i_KQ = i_KQ_0 + threadIdx.y;
+
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+ break;
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+ sum = warp_reduce_sum(sum);
+ sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+ if (ncols == 1) {
+ kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
+ } else {
+ kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+ }
+
+ if (threadIdx.x == 0) {
+ KQ[j*D + i_KQ] = sum;
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
+ KQ[j*D + tid] = val;
+
+ VKQ[j] *= __half2half2(KQ_max_scale);
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k0 = 0; k0 < D; k0 += 2) {
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+ break;
+ }
+
+ half2 V_k;
+ reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
+ reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
+ }
+ }
+
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
+ if (threadIdx.x == 0) {
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+ break;
+ }
+
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+ half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+ if (parallel_blocks == 1) {
+ dst_val /= kqsum[j_VKQ];
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+ }
+
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+ dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
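+// Added for exposition: a scalar reference of the online-softmax update that the kernel above
+// performs per KQ score. This hypothetical helper is illustrative only and is not called anywhere;
+// acc, l and m correspond to VKQ, kqsum and kqmax respectively.
+static __device__ __forceinline__ void flash_attn_vec_online_softmax_ref(
+        const float s, const float v, float & m, float & l, float & acc) {
+    const float m_new = fmaxf(m, s);      // new running maximum of the KQ scores
+    const float c     = expf(m - m_new);  // rescales all previously accumulated terms
+    const float p     = expf(s - m_new);  // softmax weight of the new score
+    acc = acc*c + p*v;
+    l   = l  *c + p;
+    m   = m_new;                          // the final attention output is acc/l
+}
+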
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ constexpr int nwarps = D/WARP_SIZE;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ constexpr bool need_f16_K = D != 128;
+ constexpr bool need_f16_V = D != 128 && D != 64;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * KQV = dst;
+ ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+ const int32_t precision = KQV->op_params[2];
+ GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+ GGML_ASSERT(K->type == type_K);
+ GGML_ASSERT(V->type == type_V);
+
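+    // NOTE (added annotation): for small batches each block covers fewer Q columns and
+    // parallel_blocks == 4 splits the KV dimension across blocks to keep the GPU busy;
+    // for Q->ne[1] > 8 the kernel runs with cols_per_block == 8 and a single block per column group.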
+ if (Q->ne[1] == 1) {
+ constexpr int cols_per_block = 1;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] == 2) {
+ constexpr int cols_per_block = 2;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 4) {
+ constexpr int cols_per_block = 4;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 8) {
+ constexpr int cols_per_block = 8;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ constexpr int cols_per_block = 8;
+ constexpr int parallel_blocks = 1;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+}
+
+#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
+ template void ggml_cuda_flash_attn_ext_vec_f16_case \
+ <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
new file mode 100644
index 00000000..11a5e355
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -0,0 +1,374 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
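+// NOTE (added annotation): f32 counterpart of fattn-vec-f16.cuh. The kernel below follows the same
+// structure but keeps the KQ scores, the softmax state and the VKQ accumulator in float; the
+// dispatch in fattn.cu selects it when fast fp16 arithmetic is unavailable or when higher precision
+// than GGML_PREC_DEFAULT is requested.
+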
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f32(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+ constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
+
+ const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ Q += nb02* blockIdx.y + nb01*ic0;
+ K += nb12*(blockIdx.y / gqa_ratio);
+ V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+ const half * maskh = (const half *) mask + ne11*ic0;
+
+ const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+ constexpr int nwarps = D / WARP_SIZE;
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ __builtin_assume(tid < D);
+
+ __shared__ float KQ[ncols*D];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ[j*D + tid] = -FLT_MAX/2.0f;
+ }
+
+ float kqmax[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqmax[j] = -FLT_MAX/2.0f;
+ }
+ float kqsum[ncols] = {0.0f};
+
+ __shared__ float kqmax_shared[ncols][WARP_SIZE];
+ __shared__ float kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.y == 0) {
+ kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
+ kqsum_shared[j][threadIdx.x] = 0.0f;
+ }
+ }
+ __syncthreads();
+
+ // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+ float2 Q_f2[ncols][D/(2*WARP_SIZE)];
+    int   Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+ float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+ if (Q_q8_1) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+
+ // Reuse KQ as temporary storage for converting Q to q8_1:
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+ // Set memory to zero if out of bounds:
+ if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ tmp_q_i32[i] = 0;
+ }
+ if (threadIdx.x < D/QK8_1) {
+ tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
+ }
+ continue;
+ }
+
+ const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+ Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
+ }
+ }
+
+ __syncthreads();
+ } else {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+ Q_f2[j][i0/WARP_SIZE].x *= scale;
+ Q_f2[j][i0/WARP_SIZE].y *= scale;
+ }
+ }
+ }
+
+ float VKQ[ncols] = {0.0f};
+
+ const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+ for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+
+ float kqmax_new_arr[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqmax_new_arr[j] = kqmax[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+ const int i_KQ = i_KQ_0 + threadIdx.y;
+
+ if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+ break;
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+ sum = warp_reduce_sum(sum);
+ sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+ kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
+
+ if (threadIdx.x == 0) {
+ KQ[j*D + i_KQ] = sum;
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float kqmax_new_j = kqmax_new_arr[j];
+
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const float val = expf(KQ[j*D + tid] - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale + val;
+ KQ[j*D + tid] = val;
+
+ VKQ[j] *= KQ_max_scale;
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int k = 0; k < D; ++k) {
+ if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
+ break;
+ }
+
+ const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j] += V_ki*KQ[j*D + k];
+ }
+ }
+
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ kqsum[j] = warp_reduce_sum(kqsum[j]);
+ if (threadIdx.x == 0) {
+ kqsum_shared[j][threadIdx.y] = kqsum[j];
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+ break;
+ }
+
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+ float dst_val = VKQ[j_VKQ];
+ if (parallel_blocks == 1) {
+ dst_val /= kqsum[j_VKQ];
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+ }
+
+ if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+ dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+ }
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ constexpr int nwarps = D/WARP_SIZE;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ constexpr bool need_f16_K = D != 128;
+ constexpr bool need_f16_V = D != 128 && D != 64;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+ GGML_ASSERT(K->type == type_K);
+ GGML_ASSERT(V->type == type_V);
+
+ if (Q->ne[1] == 1) {
+ constexpr int cols_per_block = 1;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] == 2) {
+ constexpr int cols_per_block = 2;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 4) {
+ constexpr int cols_per_block = 4;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ if (Q->ne[1] <= 8) {
+ constexpr int cols_per_block = 8;
+ constexpr int parallel_blocks = 4;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ return;
+ }
+
+ constexpr int cols_per_block = 8;
+ constexpr int parallel_blocks = 1;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+}
+
+#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
+ template void ggml_cuda_flash_attn_ext_vec_f32_case \
+ <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
new file mode 100644
index 00000000..ae232224
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -0,0 +1,490 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne10,
+ const int ne11,
+ const int ne12,
+ const int ne13,
+ const int ne31,
+ const int nb31,
+ const int nb01,
+ const int nb02,
+ const int nb03,
+ const int nb11,
+ const int nb12,
+ const int nb13,
+ const int nb21,
+ const int nb22,
+ const int nb23,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const int ne3) {
+#ifdef FP16_MMA_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
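+    // NOTE (added annotation): Q is loaded once into tensor core fragments, then the KV sequence is
+    // processed in tiles of FATTN_KQ_STRIDE rows. For every tile the KQ scores are computed with
+    // wmma, an online softmax keeps per-column running maxima and row sums (in float or half2
+    // depending on KQ_acc_t), and the softmax-weighted V rows are accumulated into VKQ. At the end
+    // each output column is divided by its row sum unless parallel_blocks > 1, in which case the
+    // per-block partial results are written to dst together with their (max, row sum) metadata in
+    // dst_meta so that they can be combined afterwards.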
+
+ const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+ constexpr int D_padded = D + 8;
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
+ const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
+ const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
+ const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
+
+ const int stride_Q = nb01 / sizeof(float);
+ const int stride_KV = nb11 / sizeof(half);
+
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+ const half slopeh = __float2half(slopef);
+ const half2 slope2 = make_half2(slopef, slopef);
+
+ frag_b Q_b[D/16][ncols/frag_n];
+
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+ float * KQ_f = (float *) KQ;
+ half2 * KQ2 = (half2 *) KQ;
+
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
+ float KQ_max_f[ncols/nwarps];
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_f[j] = -FLT_MAX/2.0f;
+ }
+
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+ half2 KQ_max_h2[ncols/nwarps];
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+ }
+
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+ half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ break;
+ }
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+ }
+ }
+
+ // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D && i >= D) {
+ break;
+ }
+ KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+ }
+ }
+
+ __syncthreads();
+
+ // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+ }
+ }
+
+ __syncthreads();
+
+ // Iterate over ne11 == previous tokens:
+ for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+ // Calculate tile of KQ:
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+ frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+ }
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+ frag_a_K K_a;
+ nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+ }
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (std::is_same<KQ_acc_t, float>::value) {
+ float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+ }
+
+ float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+ }
+ KQ_max_new = warp_reduce_max(KQ_max_new);
+
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
+ }
+ KQ_max_f[j0/nwarps] = KQ_max_new;
+
+ float KQ_rowsum_add = 0.0f;
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+ KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+ }
+ KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+ }
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+ } else {
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+ }
+
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+ }
+ KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ const int k = k0 + threadIdx.x;
+
+ const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+ KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+ }
+ KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+ }
+ }
+
+ __syncthreads();
+
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+ nvcuda::wmma::load_matrix_sync(
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+ KQ + j0*(kqar*kqs_padded) + k,
+ kqar*kqs_padded);
+ }
+ }
+
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+ frag_a_V v_a;
+ nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+ }
+ }
+ }
+
+ __syncthreads();
+
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ nvcuda::wmma::store_matrix_sync(
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+ D_padded, nvcuda::wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ half2 VKQ_scale;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+ } else {
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ break;
+ }
+
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int l = 0; l < VKQ_ratio; ++l) {
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+ }
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+ }
+ }
+
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j_VKQ = j0 + threadIdx.y;
+ if (ic0 + j_VKQ >= ne01) {
+ return;
+ }
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+ float KQ_rowsum_j;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+ } else {
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + WARP_SIZE > D && i >= D) {
+ break;
+ }
+ float dst_val = VKQ[j_VKQ*D_padded + i];
+ if (parallel_blocks == 1) {
+ dst_val /= KQ_rowsum_j;
+ }
+ dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+ }
+
+ if (parallel_blocks == 1 || threadIdx.x != 0) {
+ continue;
+ }
+
+ float2 dst_meta_val;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
+ } else {
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+ }
+ dst_meta_val.y = KQ_rowsum_j;
+ dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+
+ constexpr int nwarps = 4;
+
+ constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+ const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
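+    // NOTE (added annotation): parallel_blocks splits the work for a single Q column across several
+    // blocks along the KV dimension; it is only worth doing when the block count with
+    // parallel_blocks == 1 (blocks_num_pb1) would leave the SMs underutilized, hence the
+    // comparisons against 2*nsm below.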
+ if (4*blocks_num_pb1 < 2*nsm) {
+ constexpr int parallel_blocks = 4;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ return;
+ }
+ if (2*blocks_num_pb1 < 2*nsm) {
+ constexpr int parallel_blocks = 2;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+ return;
+ }
+ constexpr int parallel_blocks = 1;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+}
+
+#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
+ template void ggml_cuda_flash_attn_ext_wmma_f16_case \
+ <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
+// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
new file mode 100644
index 00000000..38d30b21
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -0,0 +1,345 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+#include "fattn-tile-f32.cuh"
+#include "fattn-vec-f16.cuh"
+#include "fattn-vec-f32.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
+#include <cstdint>
+
+static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ const int32_t precision = KQV->op_params[2];
+
+ if (precision != GGML_PREC_DEFAULT) {
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+ constexpr int cols_per_block = 16;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ } else {
+ constexpr int cols_per_block = 32;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+ break;
+ // case 256:
+ // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+ // break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ }
+ return;
+ }
+
+ if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+ constexpr int cols_per_block = 8;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+ }
+
+ if (Q->ne[1] <= 32) {
+ constexpr int cols_per_block = 16;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ return;
+ }
+
+ constexpr int cols_per_block = 32;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
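+
+// NOTE (added annotation): the FATTN_VEC_F16_CASE macro below expands to a compile-time dispatch on
+// the head size and on the K/V tensor types; without GGML_CUDA_FA_ALL_QUANTS only the reduced set of
+// K/V type combinations listed in the #else branch is compiled.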
+#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+ ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
+ return; \
+ } \
+
+static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ on_no_fattn_vec_case(Q->ne[0]);
+}
+
+#define FATTN_VEC_F32_CASE(D, type_K, type_V) \
+ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+ ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
+ return; \
+ } \
+
+static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+ FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ on_no_fattn_vec_case(Q->ne[0]);
+}
+
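+// NOTE (added annotation): top-level dispatch. AMD devices always use the vec kernels; devices
+// without fast fp16 fall back to the f32 vec/tile kernels; devices without fp16 tensor core (mma)
+// support use the f16 vec/tile kernels; batch size 1 with a head size divisible by 2*WARP_SIZE
+// prefers the vec kernels; everything else goes through the wmma kernel.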
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ ggml_cuda_set_device(ctx.device);
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const int32_t precision = KQV->op_params[2];
+
+ // On AMD the tile kernels perform poorly, use the vec kernel instead:
+ if (cc >= CC_OFFSET_AMD) {
+ if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ }
+ return;
+ }
+
+ if (!fast_fp16_available(cc)) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+ }
+ return;
+ }
+
+ if (!fp16_mma_available(cc)) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+ }
+ return;
+ }
+
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+ if (precision == GGML_PREC_DEFAULT) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ return;
+ } else if(Q->ne[0] <= 128) {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ return;
+ }
+ }
+
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+}
diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh
new file mode 100644
index 00000000..ad3ca7a8
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
new file mode 100644
index 00000000..55af195f
--- /dev/null
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -0,0 +1,178 @@
+#include "getrows.cuh"
+#include "dequantize.cuh"
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ dfloat2 v;
+ dequantize_kernel(src0_row, ib, iqs, v);
+
+ dst_row[iybs + iqs + 0] = v.x;
+ dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = src0_row[i00];
+}
+
+template<int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
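+    // NOTE (added annotation): each thread of k_get_rows dequantizes one (x, y) value pair, so the
+    // grid covers ne00 in steps of 2*CUDA_GET_ROWS_BLOCK_SIZE and ne00 must be even (asserted below).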
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ GGML_UNUSED(dst);
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ GGML_UNUSED(dst);
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+ const int32_t * src1_i32 = (const int32_t *) src1_d;
+
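+    // F16/F32 rows are copied as-is; quantized rows are dequantized to F32 with the type's dequantize kernel
+    // (k-quants are not implemented yet, see the default case)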
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_F32:
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q4_0:
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ default:
+ // TODO: k-quants
+ fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+ GGML_ASSERT(false);
+ break;
+ }
+}
diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh
new file mode 100644
index 00000000..bbf13023
--- /dev/null
+++ b/ggml/src/ggml-cuda/getrows.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_GET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
new file mode 100644
index 00000000..3d0d8d4e
--- /dev/null
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -0,0 +1,104 @@
+#include "im2col.cuh"
+
+template <typename T>
+static __global__ void im2col_kernel(
+ const float * x, T * dst, int64_t batch_offset,
+ int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+ int s0, int s1, int p0, int p1, int d0, int d1) {
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+ if (i >= pelements) {
+ return;
+ }
+
+ const int64_t ksize = OW * (KH > 1 ? KW : 1);
+ const int64_t kx = i / ksize;
+ const int64_t kd = kx * ksize;
+ const int64_t ky = (i - kd) / OW;
+ const int64_t ix = i % OW;
+
+ const int64_t oh = blockIdx.y;
+ const int64_t batch = blockIdx.z / IC;
+ const int64_t ic = blockIdx.z % IC;
+
+ const int64_t iiw = ix * s0 + kx * d0 - p0;
+ const int64_t iih = oh * s1 + ky * d1 - p1;
+
+ const int64_t offset_dst =
+ ((batch * OH + oh) * OW + ix) * CHW +
+ (ic * (KW * KH) + ky * KW + kx);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+}
+
+template <typename T>
+static void im2col_cuda(const float * x, T* dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t batch, int64_t batch_offset, int64_t offset_delta,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+ const int parallel_elements = OW * KW * KH;
+ const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+ dim3 block_nums(num_blocks, OH, batch * IC);
+ im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
+static void im2col_cuda_f16(const float * x, half * dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t batch, int64_t batch_offset, int64_t offset_delta,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+ im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+static void im2col_cuda_f32(const float * x, float * dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t batch, int64_t batch_offset, int64_t offset_delta,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+ im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
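+    // op_params layout: [s0, s1, p0, p1, d0, d1, is_2D] - strides, padding and dilation of the convolution plus a 2D flag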
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+ const int64_t IC = src1->ne[is_2D ? 2 : 1];
+ const int64_t IH = is_2D ? src1->ne[1] : 1;
+ const int64_t IW = src1->ne[0];
+
+ const int64_t KH = is_2D ? src0->ne[1] : 1;
+ const int64_t KW = src0->ne[0];
+
+ const int64_t OH = is_2D ? dst->ne[2] : 1;
+ const int64_t OW = dst->ne[1];
+
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const int64_t batch = src1->ne[3];
+ const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+ } else {
+ im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+ }
+}
diff --git a/ggml/src/ggml-cuda/im2col.cuh b/ggml/src/ggml-cuda/im2col.cuh
new file mode 100644
index 00000000..1ce8fae4
--- /dev/null
+++ b/ggml/src/ggml-cuda/im2col.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_IM2COL_BLOCK_SIZE 256
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
new file mode 100644
index 00000000..a452a3cc
--- /dev/null
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -0,0 +1,221 @@
+#include "common.cuh"
+
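+// Register fragments for the int8 tensor core (mma/ldmatrix) path.
+// Naming: I/J/K give the per-warp tile shape (A tiles are I x K, B tiles are J x K, C tiles are I x J);
+// ne is the number of 32 bit registers each of the 32 threads holds, i.e. tile size / WARP_SIZE.
+// For A/B each register packs four int8 values, for C each register is one int32 accumulator.
+// get_i/get_j/get_k map a thread's l-th register to its row/column within the tile.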
+struct mma_int_A_I16K4 {
+ static constexpr int I = 16;
+ static constexpr int K = 4;
+ static constexpr int ne = 2;
+
+ int x[ne] = {0};
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ const int ret = (l%2) * (I/2) + threadIdx.x / K;
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < I);
+ return ret;
+ }
+
+ static __device__ __forceinline__ int get_k(const int /* l */) {
+ const int ret = threadIdx.x % K;
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < K);
+ return ret;
+ }
+
+ __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+ const int * xs = xs0 + (threadIdx.x%I)*stride;
+ asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "+r"(x[0]), "+r"(x[1])
+ : "l"(xs));
+#else
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ x[l] = xs0[get_i(l)*stride + get_k(l)];
+ }
+#endif // defined(INT8_MMA_AVAILABLE)
+ }
+};
+
+struct mma_int_A_I16K8 {
+ static constexpr int I = 16;
+ static constexpr int K = 8;
+ static constexpr int ne = 4;
+
+ int x[ne] = {0};
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < I);
+ return ret;
+ }
+
+ static __device__ __forceinline__ int get_k(const int l) {
+ const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < K);
+ return ret;
+ }
+
+ __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+ const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+ asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+ : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+ : "l"(xs));
+#else
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ x[l] = xs0[get_i(l)*stride + get_k(l)];
+ }
+#endif // defined(INT8_MMA_AVAILABLE)
+ }
+
+ __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+ ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+ }
+};
+
+struct mma_int_B_J8K4 {
+ static constexpr int J = 8;
+ static constexpr int K = 4;
+ static constexpr int ne = 1;
+
+ int x[ne] = {0};
+
+ static __device__ __forceinline__ int get_j(const int /* l */) {
+ const int ret = threadIdx.x / K;
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < J);
+ return ret;
+ }
+
+ static __device__ __forceinline__ int get_k(const int /* l */) {
+ const int ret = threadIdx.x % K;
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < K);
+ return ret;
+ }
+
+ __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+ const int * xs = xs0 + (threadIdx.x%J)*stride;
+ asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+ : "+r"(x[0])
+ : "l"(xs));
+#else
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ x[l] = xs0[get_j(l)*stride + get_k(l)];
+ }
+#endif // defined(INT8_MMA_AVAILABLE)
+ }
+};
+
+struct mma_int_B_J8K8 {
+ static constexpr int J = 8;
+ static constexpr int K = 8;
+ static constexpr int ne = 2;
+
+ int x[ne] = {0};
+
+ static __device__ __forceinline__ int get_j(const int /* l */) {
+ const int ret = threadIdx.x / (K/2);
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < J);
+ return ret;
+ }
+
+ static __device__ __forceinline__ int get_k(const int l) {
+ const int ret = l * (K/2) + threadIdx.x % (K/2);
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < K);
+ return ret;
+ }
+
+ __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+ const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+ asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "+r"(x[0]), "+r"(x[1])
+ : "l"(xs));
+#else
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ x[l] = xs0[get_j(l)*stride + get_k(l)];
+ }
+#endif // defined(INT8_MMA_AVAILABLE)
+ }
+};
+
+struct mma_int_C_I16J8 {
+ static constexpr int I = 16;
+ static constexpr int J = 8;
+ static constexpr int ne = 4;
+
+ int x[ne] = {0};
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < I);
+ return ret;
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+ GGML_CUDA_ASSUME(ret >= 0);
+ GGML_CUDA_ASSUME(ret < J);
+ return ret;
+ }
+
+ __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+ : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[0]), "+r"(x[1])
+ : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[2]), "+r"(x[3])
+ : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+ GGML_UNUSED(mma_A);
+ GGML_UNUSED(mma_B);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+ asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+ : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+ // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[0]), "+r"(x[1])
+ : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[2]), "+r"(x[3])
+ : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[0]), "+r"(x[1])
+ : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(x[2]), "+r"(x[3])
+ : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+ GGML_UNUSED(mma_A);
+ GGML_UNUSED(mma_B);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+ }
+};
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
new file mode 100644
index 00000000..84f6387e
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -0,0 +1,150 @@
+#include "mmq.cuh"
+
+void ggml_cuda_op_mul_mat_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+
+ const int64_t nb01 = src0->nb[1];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+ const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+
+ int id = ggml_cuda_get_device();
+ const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+ const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ GGML_UNUSED(src1);
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddf_i);
+}
+
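+// Decision order: forced cuBLAS -> unsupported type -> int8 MMA available -> dp4a requirement -> forced MMQ
+// -> batch size heuristic: on Volta+ (NVIDIA) and RDNA3+ (AMD) dp4a MMQ is only used for ne11 < MMQ_DP4A_MAX_BATCH_SIZE.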
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+ return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+ bool mmq_supported;
+
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ mmq_supported = true;
+ break;
+ default:
+ mmq_supported = false;
+ break;
+ }
+
+ if (!mmq_supported) {
+ return false;
+ }
+
+ if (int8_mma_available(cc)) {
+ return true;
+ }
+
+ if (cc < MIN_CC_DP4A) {
+ return false;
+ }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+ return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+ if (cc < CC_OFFSET_AMD) {
+ return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ }
+
+ return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+}
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
new file mode 100644
index 00000000..f08a4758
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -0,0 +1,2936 @@
+#pragma once
+
+#include "common.cuh"
+#include "vecdotq.cuh"
+#include "mma.cuh"
+
+#include <climits>
+#include <cstdint>
+
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8
+
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
+
+enum mmq_q8_1_ds_layout {
+ MMQ_Q8_1_DS_LAYOUT_D4,
+ MMQ_Q8_1_DS_LAYOUT_DS4,
+ MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
+struct block_q8_1_mmq {
+ // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+ // The y float data is first grouped as blocks of 128 values.
+ // These blocks are then treated as individual data values and transposed.
+ //
+ // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+ // This padding is also used to store block scales/partial sums.
+    // Multiplying the quantized data by the scales recovers the unquantized values.
+ // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+ // and are only needed for performance reasons.
+ //
+ // The exact data stored depends on the x data type.
+ union {
+ float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+ half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+ half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                           // stored as d0,d1,s1,s2,s3,s4,s5,s6
+ };
+ int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
+};
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
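+// => sizeof(block_q8_1_mmq) == 144 bytes: 128 bytes of int8 quants + 16 bytes of scales/partial sums (the padding).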
+
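+// Maps the x data type to the block_q8_1_mmq member that holds the y scales/partial sums its kernel needs:
+// D4 -> d4 (scales only), DS4 -> ds4 (scale + partial sum per 32 values), D2S6 -> d2s6 (2 scales + 6 partial sums).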
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+ switch (type_x) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q5_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q5_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q8_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q2_K:
+ return MMQ_Q8_1_DS_LAYOUT_D2S6;
+ case GGML_TYPE_Q3_K:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_IQ1_S:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
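+// Sizes (in elements) of the buffers that make up one x tile in shared memory for the dp4a path:
+// qs = quantized values, dm = block scales (float) or scale/min pairs (half2), sc = packed sub-block scales.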
+struct tile_x_sizes {
+ int qs;
+ int dm;
+ int sc;
+};
+
+static constexpr int get_mmq_x_max_host(const int cc) {
+ return int8_mma_available(cc) ? 128 :
+#ifdef GGML_CUDA_FORCE_MMQ
+ cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
+#else
+ cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#ifdef INT8_MMA_AVAILABLE
+ return 128;
+#else // INT8_MMA_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+ return 128;
+#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= CC_VOLTA
+#ifdef GGML_CUDA_FORCE_MMQ
+ return MMQ_DP4A_MAX_BATCH_SIZE;
+#else // GGML_CUDA_FORCE_MMQ
+ return 128;
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= CC_VOLTA
+
+ return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // INT8_MMA_AVAILABLE
+}
+
+static constexpr int get_mmq_y_host(const int cc) {
+ return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
+}
+
+static constexpr __device__ int get_mmq_y_device() {
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA1)
+ return 64;
+#else
+ return 128;
+#endif // defined RDNA1
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+ return 128;
+#else
+ return 64;
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+ return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
+ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
+ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+ type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
+ type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
+ type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
+ tile_x_sizes{0, 0, 0};
+}
+
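+// Strides (in ints) between rows of the x tile in shared memory for the int8 MMA path.
+// The static_asserts below check the % 8 == 4 padding; like the y tile padding this presumably serves to avoid shared memory bank conflicts.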
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
+ type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
+ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+ type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
+ type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
+ type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
+ 0;
+}
+
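+// Stride (in ints) of one row of the y tile in shared memory:
+// WARP_SIZE ints of quantized data + WARP_SIZE/QI8_1 ints of scales/partial sums, i.e. one block_q8_1_mmq (144 bytes).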
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+ return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+ return mmq_x >= 48 ? 16 : 8;
+}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+ return 8;
+}
+#endif // INT8_MMA_AVAILABLE
+
+// ------------------------------------------------------------
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI4_0;
+ const int kqsx = threadIdx.x % QI4_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b2(bxi->qs, kqsx);
+
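+        // expand the packed 4 bit quants: the mma path stores the two nibbles as signed bytes (offset by -8),
+        // the dp4a path keeps the raw packed value and expands it during the dot product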
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+ int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+ }
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+ x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI4_1;
+ const int kqsx = threadIdx.x % QI4_1;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b4(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+ int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+ int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+ }
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+ (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+ x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI5_0;
+ const int kqsx = threadIdx.x % QI5_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+ const int ql = get_int_b2(bxi->qs, kqsx);
+ const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+ qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+ qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+ int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI5_1;
+ const int kqsx = threadIdx.x % QI5_1;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+ const int ql = get_int_b4(bxi->qs, kqsx);
+ const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+ int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI8_0;
+ const int kqsx = threadIdx.x % QI8_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+ int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+ x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+
+ mma_A A[ntx][WARP_SIZE/QI8_0];
+ float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ mma_B B;
+ float dB[mma_C::ne/2];
+
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/QI8_0], B);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+ }
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+ (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ typedef mma_int_A_I16K8 mma_A;
+ typedef mma_int_B_J8K8 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_dm = (const half2 *) y;
+
+ mma_A A[ntx][WARP_SIZE/QI8_1];
+ float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ mma_B B;
+ float2 dsB[mma_C::ne/2];
+
+ B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C;
+ C.mma_K8(A[n][k01/QI8_1], B);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+ }
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0],
+ &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+ y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_A_I16K8 mma_A_K8;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][8];
+ float dA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ mma_B B[2];
+ float dB[mma_C::ne/2];
+
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % QI2_K;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+
+#pragma unroll
+ for (int l = 0; l < QR2_K; ++l) {
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+ const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+ const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+ const float2 bxi_dmf = __half22float2(bxi->dm);
+ const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
+
+#ifdef INT8_MMA_AVAILABLE
+ x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
+#else
+ x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ float2 y_df[mmq_x/nwarps];
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ if (k01 < WARP_SIZE/2) {
+ constexpr int ns = 2;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ } else {
+ constexpr int ns = 1;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_A_I16K8 mma_A_K8;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][8];
+ float dA[ntx][mma_C::ne/2][8];
+ float mA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+ const int k0 = k00 + k01;
+
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+ dA[n][l][k01/(QI8_1/2)] = dm.x;
+ mA[n][l][k01/(QI8_1/2)] = dm.y;
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ float2 dB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+ mma_B B[2];
+
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+ mma_C Cm[2];
+ if (k01 >= WARP_SIZE * 3/4) {
+ mma_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ Cm[0].mma_K4(A1, B[0]);
+ Cm[1].mma_K4(A1, B[1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C Cd[2];
+
+ Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+ if (k01 >= WARP_SIZE * 3/4) {
+ tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+ }
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+ float2 sB[mma_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % QI3_K;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+ int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+ const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+ for (int l = 0; l < QR3_K; ++l) {
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+ const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
+ const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
+
+ const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+ }
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+ int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+ const int ksc_low = ksc % (QI3_K/8);
+ const int shift_low = 4 * (ksc / (QI3_K/8));
+ const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+ const int ksc_high = QI3_K/8;
+ const int shift_high = 2 * ksc;
+ const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+ const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+#ifdef INT8_MMA_AVAILABLE
+ const int8_t * sc8 = (const int8_t *) &sc;
+ const float d = bxi->d;
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+ }
+#else
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#ifndef INT8_MMA_AVAILABLE
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+ int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ x_df[i] = bxi->d;
+ }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+ x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+ // scale arrangement after the following two lines:
+ // - ksc == 0: sc0, sc1, sc2, sc3
+ // - ksc == 1: sc4, sc5, sc6, sc7
+ // - ksc == 2: m0, m1, m2, m3
+ // - ksc == 3: m4, m5, m6, m7
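+    // (This assumes the usual 12-byte K-quant scale packing: bytes 0-3 hold the 6-bit
+    //  sc0-sc3 with the upper 2 bits of sc4-sc7 in their top bits, bytes 4-7 hold m0-m3
+    //  likewise, and bytes 8-11 hold the low nibbles of sc4-sc7 and m4-m7.)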
+ return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+ ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+ x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
+
+#else
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
+ int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+ }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+ &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const int ky = QR5_K*threadIdx.x;
+
+ const int ql = get_int_b4(bxi->qs, threadIdx.x);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
+ const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+ const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+ const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
+ const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+ x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+ int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+ for (int l = 0; l < sizeof(int); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
+
+#else
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
+ int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+ int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = threadIdx.x % (WARP_SIZE/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+ x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+ }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+ &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+ int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+
+ const int ql = get_int_b2(bxi->ql, threadIdx.x);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
+ const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
+ const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030;
+
+ const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
+ const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+ int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d;
+#else
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+ x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
+
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+ &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+ x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+ typedef mma_int_A_I16K4 mma_A;
+ typedef mma_int_B_J8K4 mma_B;
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+ const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+ mma_A A[ntx][8];
+ int scA[ntx][mma_C::ne/2][8];
+ float dA[ntx][mma_C::ne/2];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+ A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+ const int8_t * sc = (const int8_t *) &sc_packed;
+
+#pragma unroll
+ for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+ scA[n][l][k01/4 + ksc] = sc[ksc];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+ float tmp[ntx][mma_C::ne] = {{0.0f}};
+
+#pragma unroll
+ for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+ mma_B B[2];
+ float dB[mma_C::ne/2];
+
+ B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+ B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne/2; ++l) {
+ const int j = j0 + mma_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ mma_C C[2];
+ C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+ C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+ }
+ }
+ }
+#else
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = threadIdx.x / QI4_NL;
+ const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+ int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_XXS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+ const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+ const uint8_t * aux8 = (const uint8_t *) &q2;
+ const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+ for (int l = 0; l < QR2_XXS; ++l) {
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_XS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+ const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+ #pragma unroll
+ for (int l = 0; l < QR2_XS; ++l) {
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI2_S/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR2_S; ++l) {
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI3_XXS/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+ const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
+ const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+ for (int l = 0; l < QR3_XXS; ++l) {
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % (QI3_S/2);
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
+ int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+ const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR3_S; ++l) {
+ const int2 grid_pos = make_int2(
+ iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+ iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+ const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d;
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kqsx = threadIdx.x % QI1_S;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
+ int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ #pragma unroll
+ for (int l = 0; l < QR1_S/2; ++l) {
+ const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+ const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+ x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+ const int kbx = 0; // threadIdx.x / QI4_XS
+ const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + threadIdx.y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4);
+ const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+ const float d = __half2float(bxi->d);
+
+ const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+ | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+ x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+ }
+}
+
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+ const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ }
+ }
+}
+
+template<int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(
+ const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+ typedef mma_int_C_I16J8 mma_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+#ifdef INT8_MMA_AVAILABLE
+ static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < mma_C::ne; ++l) {
+ const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+
+ if (j > j_max) {
+ continue;
+ }
+
+ const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+ }
+ }
+ }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
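+// Each quantized type maps to a tile loader plus two dot product kernels: an int8 MMA variant used when
+// INT8_MMA_AVAILABLE and a dp4a fallback. Several i-quant types reuse the q8_0/q8_1 dot products because
+// their load_tiles routines already expand the data to 8-bit values.
+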
+template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+struct mmq_type_traits;
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+ static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_DS4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+ static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+ static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+ static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+ static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+ static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+ static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+ static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+ static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+ static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+ static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+ static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+ static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+ static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+ static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+ static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+ static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+ static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+ const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+ const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int mmq_y = get_mmq_y_device();
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+
+ extern __shared__ char data_mul_mat_q[];
+ int * tile_y = (int *) data_mul_mat_q;
+ int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
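+    // Shared memory layout: the y tile (mmq_x rows of WARP_SIZE + WARP_SIZE/QI8_1 ints each) comes first,
+    // padded to a multiple of nwarps*WARP_SIZE ints, followed by the x tile; the total is expected to
+    // match what mmq_get_shmem reserves.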
+
+#ifdef INT8_MMA_AVAILABLE
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
+
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+ const int tile_x_max_i = ne01 - it*mmq_y - 1;
+ const int tile_y_max_j = ne11 - jt*mmq_x - 1;
+
+ const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
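+    // yc holds the y data repacked as block_q8_1_mmq; skip ahead to the blocks belonging to this j-tile.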
+
+ for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
+ load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+
+ {
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+ tile_y[l] = by0[l];
+ }
+ }
+
+ __syncthreads();
+
+ vec_dot(tile_x, tile_y, sum, 0);
+
+ __syncthreads();
+
+ {
+ const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+ int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+ tile_y[l] = by0[l];
+ }
+ }
+
+ __syncthreads();
+
+ vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+
+ __syncthreads();
+ }
+
+ if (fixup) {
+ write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+ } else {
+ write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+ }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
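+// Roughly: the grid is sized to the number of SMs and each CUDA block walks a contiguous slice of the
+// flattened (tile_j, tile_i, k-block) index space. A slice that ends inside an output tile writes its
+// partial sums to tmp_fixup instead of dst; mul_mat_q_stream_k_fixup then adds those partial sums to dst.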
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+ __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+ __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+ const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+ // Skip unused template specializations for faster compilation:
+ if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or on old CUDA architectures the performance with stream-k was worse; use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+ {
+ constexpr bool fixup = false;
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ blockIdx.x, blockIdx.y, 0, ne00/qk);
+ return;
+ }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+ const int64_t blocks_per_ne00 = ne00 / qk;
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+
+ const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
+ int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
+
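+    // Round both ends of this block's k-range down to a multiple of blocks_per_iter within the current
+    // output tile so that every iteration of the loop below consumes a full MMQ_ITER_K chunk.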
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+ kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+
+ // kb0 == k index when doing the matrix multiplication for an output tile.
+ int kb0_start = kbc % blocks_per_ne00;
+ int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+ while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+ const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
+ const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iteration write their data to dst rather than to the fixup buffer.
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ it, jt, kb0_start, kb0_stop);
+
+ kbc += blocks_per_ne00;
+ kbc -= kbc % blocks_per_ne00;
+
+ kb0_start = 0;
+ kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
+ }
+
+ if (kbc >= kbc_stop) {
+ return;
+ }
+
+ const int jt = kbc / (blocks_per_ne00*nty);
+ const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // The last index writes its data to the fixup buffer to avoid data races with other blocks.
+ mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+ (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+ it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+ float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+ constexpr int mmq_y = get_mmq_y_device();
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+ const int64_t blocks_per_ne00 = ne00 / qk;
+
+ float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+ const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+ const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+ bool any_fixup = false;
+
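+    // Scan the MMQ CUDA blocks whose k-ranges could have ended inside this output tile and accumulate
+    // any partial sums they wrote to tmp_last_tile.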
+ const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
+ const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+
+ int64_t kbc_0;
+ int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+ for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+ kbc_0 = kbc_stop_0;
+ kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+ const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
+ const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
+
+ // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+ if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+ continue;
+ }
+
+ const int jt = kbc_stop / (blocks_per_ne00*nty);
+ const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+ // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+ if (it != blockIdx.x || jt != blockIdx.y) {
+ continue;
+ }
+
+ any_fixup = true;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+ }
+ }
+ }
+
+ if (!any_fixup) {
+ return;
+ }
+
+ dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+ const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+ const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+ }
+ }
+}
+
+struct mmq_args {
+ const char * x; const char * y; float * dst;
+ int64_t ne00; int64_t ne01; int64_t stride01;
+ int64_t ne10; int64_t ne11; int64_t stride11;
+ int64_t ne0;
+};
+
+template<ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+ const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+ const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+ const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
+ return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+}
+
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int mmq_y = get_mmq_y_host(cc);
+
+ const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+
+ const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+ static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shmem_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+ CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+ shmem_limit_raised[id] = true;
+ }
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+ const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+ const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+ const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
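+    // On Volta and newer NVIDIA GPUs use stream-k decomposition: launch exactly one CUDA block per SM
+    // and let each block work through a contiguous range of (tile, k-block) work items; partially
+    // accumulated tiles are reconciled afterwards by mul_mat_q_stream_k_fixup. Otherwise launch one
+    // CUDA block per output tile.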
+ const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+ if (!use_stream_k) {
+ if (args.ne01 % mmq_y == 0) {
+ constexpr bool need_check = false;
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ } else {
+ constexpr bool need_check = true;
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+ }
+ return;
+ }
+
+ const dim3 block_nums_mmq(nsm, 1, 1);
+
+ ggml_cuda_pool & pool = ctx.pool(id);
+ ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
+ if (args.ne01 % mmq_y == 0) {
+ constexpr bool need_check = false;
+
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+ (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+ } else {
+ constexpr bool need_check = true;
+
+ mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+ (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+ (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+ }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+ const int id = ggml_cuda_get_device();
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int smpbo = ggml_cuda_info().devices[id].smpbo;
+
+ const int mmq_x_max = get_mmq_x_max_host(cc);
+ const int mmq_y = get_mmq_y_host(cc);
+ const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+ const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+
+ int mmq_x_best = 0;
+ int nparts_best = INT_MAX;
+
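+    // Choose the tile width mmq_x that minimizes the number of work partitions (tiles along ne11 for
+    // stream-k, total output tiles otherwise) while fitting into the available shared memory per block
+    // and respecting the tile granularity; stop early once a single partition suffices.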
+ for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+ const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+ if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+ continue;
+ }
+
+ const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+ const int nwaves_xy_tiling = ntiles_x*block_num_y;
+ const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+
+ if (nparts < nparts_best) {
+ mmq_x_best = mmq_x;
+ nparts_best = nparts;
+ }
+ }
+
+ switch (mmq_x_best) {
+ case 8:
+ launch_mul_mat_q<type, 8>(ctx, args, stream);
+ break;
+ case 16:
+ launch_mul_mat_q<type, 16>(ctx, args, stream);
+ break;
+ case 24:
+ launch_mul_mat_q<type, 24>(ctx, args, stream);
+ break;
+ case 32:
+ launch_mul_mat_q<type, 32>(ctx, args, stream);
+ break;
+ case 40:
+ launch_mul_mat_q<type, 40>(ctx, args, stream);
+ break;
+ case 48:
+ launch_mul_mat_q<type, 48>(ctx, args, stream);
+ break;
+ case 56:
+ launch_mul_mat_q<type, 56>(ctx, args, stream);
+ break;
+ case 64:
+ launch_mul_mat_q<type, 64>(ctx, args, stream);
+ break;
+ case 72:
+ launch_mul_mat_q<type, 72>(ctx, args, stream);
+ break;
+ case 80:
+ launch_mul_mat_q<type, 80>(ctx, args, stream);
+ break;
+ case 88:
+ launch_mul_mat_q<type, 88>(ctx, args, stream);
+ break;
+ case 96:
+ launch_mul_mat_q<type, 96>(ctx, args, stream);
+ break;
+ case 104:
+ launch_mul_mat_q<type, 104>(ctx, args, stream);
+ break;
+ case 112:
+ launch_mul_mat_q<type, 112>(ctx, args, stream);
+ break;
+ case 120:
+ launch_mul_mat_q<type, 120>(ctx, args, stream);
+ break;
+ case 128:
+ launch_mul_mat_q<type, 128>(ctx, args, stream);
+ break;
+ default:
+ fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
+#define DECL_MMQ_CASE(type) \
+ template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
+
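+// The matching explicit template instantiations (DECL_MMQ_CASE without extern) live in separate
+// compilation units; only the declarations are visible here.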
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+
+// -------------------------------------------------------------------------------------------------------------------------
+
+void ggml_cuda_op_mul_mat_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
new file mode 100644
index 00000000..b44000cd
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -0,0 +1,447 @@
+#include "mmvq.cuh"
+#include "vecdotq.cuh"
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+ return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
+ type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
+ type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
+ type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+ type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
+ type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
+ type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
+ type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
+ type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
+ type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
+ type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+ type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
+ type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
+ type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+ type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
+ type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
+ type == GGML_TYPE_IQ1_BN ? vec_dot_iq1_bn_q8_1 :
+ type == GGML_TYPE_IQ2_BN ? vec_dot_iq2_bn_q8_1 :
+ type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
+ type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+ type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
+ nullptr;
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+ return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
+ type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ :
+ type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ :
+ type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ :
+ type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ :
+ type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ :
+ type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ :
+ type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ :
+ type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ :
+ type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ :
+ type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ :
+ 1;
+}
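+// get_vdr_mmvq returns how many contiguous quantized values each thread processes per vec_dot call;
+// qi/vdr threads then cooperate on a single block of x.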
+
+template <ggml_type type, int ncols_y>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void mul_mat_vec_q(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int vdr = get_vdr_mmvq(type);
+
+ constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+ constexpr int nwarps = 1;
+ constexpr int rows_per_cuda_block = 1;
+#else
+ constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
+ constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ const int row0 = rows_per_cuda_block*blockIdx.x;
+ const int blocks_per_row_x = ncols_x / qk;
+ const int blocks_per_col_y = nrows_y / QK8_1;
+ constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+
+    // partial sum for each thread
+ float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+ const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+ // x block quant index when casting the quants to int
+ const int kqs = vdr * (tid % (qi/vdr));
+
+#pragma unroll
+ for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+ tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+ }
+ }
+ }
+
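+    // Reduce the partial sums across warps: warps 1..nwarps-1 stash their accumulators in shared
+    // memory, then warp 0 adds them to its own accumulators, finishes with a warp-level reduction
+    // and writes the result.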
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+ if (threadIdx.y > 0) {
+#pragma unroll
+ for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+ tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+ }
+ }
+ }
+ __syncthreads();
+ if (threadIdx.y > 0) {
+ return;
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+ for (int l = 0; l < nwarps-1; ++l) {
+ tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+ }
+ tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+ }
+
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+ dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+ }
+ }
+}
+
+template <ggml_type type>
+static void mul_mat_vec_q_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
+ GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+
+ int id = ggml_cuda_get_device();
+
+ int64_t nwarps = 1;
+ int64_t rows_per_cuda_block = 1;
+
+ if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+ switch(ncols_y) {
+ case 1:
+ nwarps = 4;
+ rows_per_cuda_block = 1;
+ break;
+ case 2:
+ case 3:
+ case 4:
+ nwarps = 4;
+ rows_per_cuda_block = 2;
+ break;
+ case 5:
+ case 6:
+ case 7:
+ case 8:
+ nwarps = 2;
+ rows_per_cuda_block = 2;
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ }
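+    // These values must mirror the constexpr nwarps/rows_per_cuda_block chosen inside mul_mat_vec_q
+    // for the same ncols_y, since they determine the grid and block dimensions used for the launch.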
+ const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+ const dim3 block_nums(nblocks, 1, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ switch (ncols_y) {
+ case 1:
+ mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 2:
+ mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 3:
+ mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 4:
+ mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 5:
+ mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 6:
+ mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 7:
+ mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ case 8:
+ mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
+static void mul_mat_vec_q4_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q2_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_s_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_s_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_m_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_bn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ1_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_bn_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ2_BN>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_s_q8_1_cuda(
+ const void * vx, const void * vy, float * dst,
+ const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+ mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+void ggml_cuda_op_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ int id = ggml_cuda_get_device();
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ1_M:
+ mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ1_BN:
+ mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ2_BN:
+ mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ GGML_UNUSED(src1);
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src1_ddf_i);
+ GGML_UNUSED(src1_ncols);
+ GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
new file mode 100644
index 00000000..d9e42fdd
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -0,0 +1,9 @@
+#include "common.cuh"
+
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
+void ggml_cuda_op_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
new file mode 100644
index 00000000..30866d51
--- /dev/null
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -0,0 +1,221 @@
+#include "norm.cuh"
+
+template <int block_size>
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ float2 mean_var = make_float2(0.f, 0.f);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ mean_var.x += xi;
+ mean_var.y += xi * xi;
+ }
+
+ // sum up partial sums
+ mean_var = warp_reduce_sum(mean_var);
+ if (block_size > WARP_SIZE) {
+ __shared__ float2 s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = mean_var;
+ }
+ __syncthreads();
+ mean_var = s_sum[lane_id];
+ mean_var = warp_reduce_sum(mean_var);
+ }
+
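+    // mean_var.x holds sum(x), mean_var.y holds sum(x^2); the variance follows from
+    // Var(x) = E[x^2] - E[x]^2.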
+ const float mean = mean_var.x / ncols;
+ const float var = mean_var.y / ncols - mean * mean;
+ const float inv_std = rsqrtf(var + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+ }
+}
+
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+ // blockIdx.x: num_groups idx
+ // threadIdx.x: block_size idx
+ int start = blockIdx.x * group_size;
+ int end = start + group_size;
+
+ start += threadIdx.x;
+
+ if (end >= ne_elements) {
+ end = ne_elements;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += block_size) {
+ tmp += x[j];
+ }
+
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ float mean = tmp / group_size;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += block_size) {
+ float xi = x[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ float variance = tmp / group_size;
+ float scale = rsqrtf(variance + eps);
+ for (int j = start; j < end; j += block_size) {
+ dst[j] *= scale;
+ }
+}
+
+template <int block_size>
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
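+    // RMS norm scales each value by 1/sqrt(mean(x^2) + eps); unlike regular norm the mean is not
+    // subtracted.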
+ const float mean = tmp / ncols;
+ const float scale = rsqrtf(mean + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row*ncols + col] = scale * x[row*ncols + col];
+ }
+}
+
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ }
+}
+
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
+ static const float eps = 1e-6f;
+ if (group_size < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ }
+}
+
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+ }
+}
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int num_groups = dst->op_params[0];
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+ group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
new file mode 100644
index 00000000..431a8f74
--- /dev/null
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
new file mode 100644
index 00000000..aba539e8
--- /dev/null
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -0,0 +1,49 @@
+#include "pad.cuh"
+
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+ // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+ // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ // operation
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * ne00 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ dst[offset_dst] = 0.0f;
+ }
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+ const int ne00, const int ne01, const int ne02, const int ne03,
+ const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2*ne3);
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+ pad_f32_cuda(src0_d, dst_d,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
new file mode 100644
index 00000000..8fd386b0
--- /dev/null
+++ b/ggml/src/ggml-cuda/pad.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_PAD_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/pool2d.cu b/ggml/src/ggml-cuda/pool2d.cu
new file mode 100644
index 00000000..c6d51e4d
--- /dev/null
+++ b/ggml/src/ggml-cuda/pool2d.cu
@@ -0,0 +1,94 @@
+#include "pool2d.cuh"
+
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const Ti* src, To* dst, const enum ggml_op_pool op) {
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (idx >= parallel_elements) {
+ return;
+ }
+
+ const int I_HW = ih * iw;
+ const int O_HW = oh * ow;
+ const int nc = idx / O_HW;
+ const int cur_oh = idx % O_HW / ow;
+ const int cur_ow = idx % O_HW % ow;
+ const Ti* i_ptr = src + nc * I_HW;
+ To* o_ptr = dst + nc * O_HW;
+ const int start_h = cur_oh * sh - ph;
+ const int bh = max(0, start_h);
+ const int eh = min(ih, start_h + kh);
+ const int start_w = cur_ow * sw - pw;
+ const int bw = max(0, start_w);
+ const int ew = min(iw, start_w + kw);
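+    // The pooling window is clamped to the input extent; padded positions contribute nothing to the
+    // sum, but the average is still divided by the full kernel size kh*kw.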
+ const To scale = 1. / (kh * kw);
+ To res = 0;
+
+ switch (op) {
+ case GGML_OP_POOL_AVG: res = 0; break;
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+ default: assert(false);
+ }
+
+ for (int i = bh; i < eh; i += 1) {
+ for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+ Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+ Ti cur = i_ptr[i * iw + j];
+#endif
+ switch (op) {
+ case GGML_OP_POOL_AVG: res += cur * scale; break;
+ case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+ default: assert(false);
+ }
+ }
+ }
+ o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+static void pool2d_nchw_kernel_f32_f32_cuda(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const float * src, float * dst, const enum ggml_op_pool op,
+ cudaStream_t stream) {
+
+ const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+ dim3 block_nums(num_blocks);
+ pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
+}
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ const int64_t IH = src0->ne[1];
+ const int64_t IW = src0->ne[0];
+
+ const int64_t N = dst->ne[3];
+ const int64_t OC = dst->ne[2];
+ const int64_t OH = dst->ne[1];
+ const int64_t OW = dst->ne[0];
+
+ const int parallel_elements = N * OC * OH * OW;
+
+ pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
+}
diff --git a/ggml/src/ggml-cuda/pool2d.cuh b/ggml/src/ggml-cuda/pool2d.cuh
new file mode 100644
index 00000000..7841292b
--- /dev/null
+++ b/ggml/src/ggml-cuda/pool2d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_POOL2D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
new file mode 100644
index 00000000..aa7f1eff
--- /dev/null
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -0,0 +1,169 @@
+#include "quantize.cuh"
+#include <cstdint>
+
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+ const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (ix0 >= kx0_padded) {
+ return;
+ }
+
+ const int64_t ix1 = blockIdx.y;
+
+ const int64_t i_padded = ix1*kx0_padded + ix0;
+
+ block_q8_1 * y = (block_q8_1 *) vy;
+
+ const int64_t ib = i_padded / QK8_1; // block index
+ const int64_t iqs = i_padded % QK8_1; // quant index
+
+ const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+ float amax = fabsf(xi);
+ float sum = xi;
+
+ amax = warp_reduce_max(amax);
+ sum = warp_reduce_sum(sum);
+
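+    // The block scale d maps the max. abs. value in the block onto the int8 range [-127, 127],
+    // e.g. amax = 2.54 -> d = 0.02, so xi = 1.27 is stored as roundf(1.27/0.02) = 64.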
+ const float d = amax / 127;
+ const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+ y[ib].qs[iqs] = q;
+
+ if (iqs > 0) {
+ return;
+ }
+
+ reinterpret_cast<half&>(y[ib].ds.x) = d;
+ reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1(
+ const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+ constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+ constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
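+    // Layout of the per-block scales/sums written at the end of this kernel, depending on ds_layout:
+    //   D4:   one float scale per 32 values, no sums.
+    //   DS4:  one half2 (scale, sum) pair per 32 values.
+    //   D2S6: one float scale per 64 values plus six partial sums (one per 16 values, covering only
+    //         the first 96 of the 128 values in the block).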
+
+ const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+ if (ix0 >= kx0_padded) {
+ return;
+ }
+
+ const float4 * x4 = (const float4 *) x;
+
+ const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+ block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+ const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+ const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
+ const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
+
+ // Load 4 floats per thread and calculate max. abs. value between them:
+ const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+ float amax = fabsf(xi.x);
+ amax = fmaxf(amax, fabsf(xi.y));
+ amax = fmaxf(amax, fabsf(xi.z));
+ amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange the max. abs. value across vals_per_scale/4 threads.
+#pragma unroll
+ for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+ }
+
+ float sum;
+ if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+ sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange the calculated sum across vals_per_sum/4 threads.
+#pragma unroll
+ for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+ }
+ }
+
+ const float d_inv = 127.0f / amax;
+ char4 q;
+ q.x = roundf(xi.x*d_inv);
+ q.y = roundf(xi.y*d_inv);
+ q.z = roundf(xi.z*d_inv);
+ q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
+ char4 * yqs4 = (char4 *) y[ib].qs;
+ yqs4[iqs/4] = q;
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+ if (iqs % 16 != 0 || iqs >= 96) {
+ return;
+ }
+
+ y[ib].d2s6[2 + iqs/16] = sum;
+
+ if (iqs % 64 != 0) {
+ return;
+ }
+
+ const float d = 1.0f / d_inv;
+
+ y[ib].d2s6[iqs/64] = d;
+
+ return;
+ }
+
+ if (iqs % 32 != 0) {
+ return;
+ }
+
+ const float d = 1.0f / d_inv;
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+ y[ib].ds4[iqs/32] = make_half2(d, sum);
+ } else {
+ y[ib].d4[iqs/32] = d;
+ }
+}
+
+void quantize_row_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+ const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+ GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+ const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ const dim3 num_blocks(block_num_x, kx1*channels, 1);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+ quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+ GGML_UNUSED(type_x);
+}
+
+void quantize_mmq_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+ const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+ GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+ const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+ const dim3 num_blocks(block_num_x, kx1, channels);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+ switch (mmq_get_q8_1_ds_layout(type_x)) {
+ case MMQ_Q8_1_DS_LAYOUT_D4:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_DS4:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_D2S6:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+ <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh
new file mode 100644
index 00000000..03bf322b
--- /dev/null
+++ b/ggml/src/ggml-cuda/quantize.cuh
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
+
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
+
+typedef void (*quantize_cuda_t)(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
+
+void quantize_row_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
+
+void quantize_mmq_q8_1_cuda(
+ const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+ const ggml_type type_x, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
new file mode 100644
index 00000000..596fb7c1
--- /dev/null
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -0,0 +1,271 @@
+#include "rope.cuh"
+
+struct rope_corr_dims {
+ float v[2];
+};
+
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
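+// The ramp is 1 for rotational dimensions (i0/2) below `low`, 0 above `high`, with a linear
+// transition in between; rope_yarn scales it by ext_factor to blend between the interpolated and
+// the extrapolated angle.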
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+ float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+ float * cos_theta, float * sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ *cos_theta = cosf(theta) * mscale;
+ *sin_theta = sinf(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + 1];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_neox(
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0/2;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + n_dims/2];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T>
+static void rope_norm_cuda(
+ const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ } else {
+ rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ }
+}
+
+template<typename T>
+static void rope_neox_cuda(
+ const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ } else {
+ rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, freq_factors
+ );
+ }
+}
+
+static void rope_norm_cuda_f16(
+ const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_norm_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_norm_cuda_f32(
+ const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_norm_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f16(
+ const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+ rope_neox_cuda<half>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f32(
+ const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
+) {
+
+ rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t nr = ggml_nrows(src0);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ // RoPE alteration for extended context
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ const bool is_neox = mode & 2;
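+    // NeoX-style RoPE rotates element pairs that are n_dims/2 apart, the default mode rotates
+    // adjacent element pairs.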
+
+ const int32_t * pos = (const int32_t *) src1_d;
+
+ const float * freq_factors = nullptr;
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+
+ rope_corr_dims corr_dims;
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+ // compute
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_neox_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_neox_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+ } else {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_norm_cuda_f32(
+ (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_norm_cuda_f16(
+ (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, stream
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+}
diff --git a/ggml/src/ggml-cuda/rope.cuh b/ggml/src/ggml-cuda/rope.cuh
new file mode 100644
index 00000000..0f787a0b
--- /dev/null
+++ b/ggml/src/ggml-cuda/rope.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ROPE_BLOCK_SIZE 256
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu
new file mode 100644
index 00000000..1405e066
--- /dev/null
+++ b/ggml/src/ggml-cuda/scale.cu
@@ -0,0 +1,31 @@
+#include "scale.cuh"
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = scale * x[i];
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+ scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(float));
+
+ scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+}
diff --git a/ggml/src/ggml-cuda/scale.cuh b/ggml/src/ggml-cuda/scale.cuh
new file mode 100644
index 00000000..8ff75c82
--- /dev/null
+++ b/ggml/src/ggml-cuda/scale.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
new file mode 100644
index 00000000..c24abae1
--- /dev/null
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -0,0 +1,206 @@
+#include "common.cuh"
+#include "softmax.cuh"
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+ return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+ return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+ const int tid = threadIdx.x;
+ const int rowx = blockIdx.x;
+ const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+ const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+
+ const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+
+ extern __shared__ float data_soft_max_f32[];
+ float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+ // shared memory buffer to cache values between iterations:
+ float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+
+ float max_val = -INFINITY;
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const int64_t ix = (int64_t)rowx*ncols + col;
+ const int64_t iy = (int64_t)rowy*ncols + col;
+
+ const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+
+ vals[col] = val;
+ max_val = max(max_val, val);
+ }
+
+ // find the max value in the block
+ max_val = warp_reduce_max(max_val);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf_iw[lane_id] = -INFINITY;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf_iw[warp_id] = max_val;
+ }
+ __syncthreads();
+
+ max_val = buf_iw[lane_id];
+ max_val = warp_reduce_max(max_val);
+ }
+
+ float tmp = 0.0f; // partial sum
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const float val = expf(vals[col] - max_val);
+ tmp += val;
+ vals[col] = val;
+ }
+
+ // find the sum of exps in the block
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __syncthreads();
+ if (warp_id == 0) {
+ buf_iw[lane_id] = 0.0f;
+ }
+ __syncthreads();
+
+ if (lane_id == 0) {
+ buf_iw[warp_id] = tmp;
+ }
+ __syncthreads();
+
+ tmp = buf_iw[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ return;
+ }
+
+ const int64_t idst = (int64_t)rowx*ncols + col;
+ dst[idst] = vals[col] * inv_sum;
+ }
+}
+
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+ int nth = WARP_SIZE;
+ while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ const dim3 block_dims(nth, 1, 1);
+ const dim3 block_nums(nrows_x, 1, 1);
+ const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+ static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+ if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+ switch (ncols_x) {
+ case 32:
+ soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 64:
+ soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 128:
+ soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 256:
+ soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 512:
+ soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 1024:
+ soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 2048:
+ soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ case 4096:
+ soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ default:
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ break;
+ }
+ } else {
+ const size_t shmem_low = WARP_SIZE*sizeof(float);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+ }
+}
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ const float * src0_d = (const float *)src0->data;
+ const void * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ if (use_f16) {
+ const half * src1_dd = (const half *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ } else {
+ const float * src1_dd = (const float *)src1_d;
+
+ soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+ }
+}
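Per row, the kernel above scales the logits, optionally adds the mask weighted by an ALiBi slope derived from max_bias, and takes a numerically stable softmax; when the row fits in shared memory the exponentiated values are cached there, otherwise dst itself is used as scratch. A CPU reference sketch of the same computation, assuming get_alibi_slope follows the usual ALiBi scheme (m0^(h+1) for the first n_head_log2 heads, m1^(2(h - n_head_log2)+1) afterwards); the slope formula is an assumption, not taken from this diff:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Assumed ALiBi slope schedule; hedged, see note above.
static float alibi_slope_ref(float max_bias, int h, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) return 1.0f;
    return h < (int) n_head_log2 ? powf(m0, h + 1)
                                 : powf(m1, 2*(h - (int) n_head_log2) + 1);
}

static void soft_max_ref(const float * x, const float * mask, float * dst,
                         int ncols, int nrows_x, int nrows_y, float scale, float max_bias) {
    const uint32_t n_head      = nrows_x/nrows_y;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       )/n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f)/n_head_log2);

    std::vector<float> v(ncols);
    for (int row = 0; row < nrows_x; ++row) {
        const float   slope = alibi_slope_ref(max_bias, row/nrows_y, n_head_log2, m0, m1);
        const float * mrow  = mask ? mask + (int64_t)(row % nrows_y)*ncols : nullptr; // mask is broadcast over rows

        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            v[c]    = x[(int64_t)row*ncols + c]*scale + (mrow ? slope*mrow[c] : 0.0f);
            max_val = std::max(max_val, v[c]);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) { v[c] = expf(v[c] - max_val); sum += v[c]; }
        for (int c = 0; c < ncols; ++c) { dst[(int64_t)row*ncols + c] = v[c]/sum; }
    }
}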
diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh
new file mode 100644
index 00000000..4ef4ff86
--- /dev/null
+++ b/ggml/src/ggml-cuda/softmax.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu
new file mode 100644
index 00000000..82e8e875
--- /dev/null
+++ b/ggml/src/ggml-cuda/sumrows.cu
@@ -0,0 +1,40 @@
+#include "sumrows.cuh"
+
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.x;
+ const int col = threadIdx.x;
+
+ float sum = 0.0f;
+ for (int i = col; i < ncols; i += blockDim.x) {
+ sum += x[row * ncols + i];
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ if (col == 0) {
+ dst[row] = sum;
+ }
+}
+
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+ k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+}
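Each block here is a single warp (block_dims = WARP_SIZE): every lane accumulates a strided partial sum over the row, and warp_reduce_sum from common.cuh folds the 32 partials together before lane 0 writes the result. The reduction itself is not shown in this diff; it typically looks like the following butterfly shuffle (a sketch of the expected pattern, not necessarily the exact definition in common.cuh):

// Butterfly reduction over one warp; after the loop every lane holds the full sum.
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // add the value held by the lane whose id differs in bit `offset`
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}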
diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh
new file mode 100644
index 00000000..e7545f83
--- /dev/null
+++ b/ggml/src/ggml-cuda/sumrows.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
new file mode 100644
index 00000000..6696a238
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
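The long series of near-identical files that follows exists only to spread compilation of the flash-attention templates across translation units: each DECL_FATTN_VEC_F16_CASE(head_size, type_K, type_V) line presumably expands to an explicit instantiation of the corresponding kernel template, so every .cu file compiles exactly one (head size, K type, V type) combination and the build can parallelize. The general C++ pattern, with hypothetical names rather than the real fattn-vec-f16.cuh API:

#include <cstdio>

// Hypothetical stand-in for the templated flash-attention entry point.
template <int D, int type_K, int type_V>
void flash_attn_vec_case_sketch() {
    std::printf("head size %d, K type %d, V type %d\n", D, type_K, type_V);
}

// What a DECL_..._CASE macro would expand to: an explicit instantiation definition,
// forcing this one combination to be compiled in this translation unit.
template void flash_attn_vec_case_sketch<128, 0, 0>();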
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
new file mode 100644
index 00000000..dd070db2
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
new file mode 100644
index 00000000..54dcde6f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
new file mode 100644
index 00000000..4ec22f79
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
new file mode 100644
index 00000000..3c15bf7f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
new file mode 100644
index 00000000..7e61b5fd
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
new file mode 100644
index 00000000..fdb15b58
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
new file mode 100644
index 00000000..0f7c417d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
new file mode 100644
index 00000000..851f33c4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
new file mode 100644
index 00000000..763809cb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
new file mode 100644
index 00000000..f2a276e5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
new file mode 100644
index 00000000..cb227f6f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
new file mode 100644
index 00000000..97ac0520
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
new file mode 100644
index 00000000..c772b426
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
new file mode 100644
index 00000000..5cb74308
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
new file mode 100644
index 00000000..98a709d1
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
new file mode 100644
index 00000000..4f2f947a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
new file mode 100644
index 00000000..11f96b6f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
new file mode 100644
index 00000000..b39bdc06
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
new file mode 100644
index 00000000..bbd6a2c7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
new file mode 100644
index 00000000..9d84ff2b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
new file mode 100644
index 00000000..bc8a5bff
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
new file mode 100644
index 00000000..a679100c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu
new file mode 100644
index 00000000..8f21bccf
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu
new file mode 100644
index 00000000..858b00fd
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu
new file mode 100644
index 00000000..0fc8011f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu
new file mode 100644
index 00000000..261fdf62
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu
new file mode 100644
index 00000000..0fb82473
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu
new file mode 100644
index 00000000..a9d9d089
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu
new file mode 100644
index 00000000..7d7b2792
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu
new file mode 100644
index 00000000..a092ee2d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu
new file mode 100644
index 00000000..db55927a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
new file mode 100644
index 00000000..c3c21cef
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
new file mode 100644
index 00000000..35dd9f52
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
new file mode 100644
index 00000000..050c22ac
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
new file mode 100644
index 00000000..de4866c5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
new file mode 100644
index 00000000..57a10bc4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
new file mode 100644
index 00000000..e0f08b46
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
new file mode 100644
index 00000000..1c8e8a46
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
new file mode 100644
index 00000000..cefed83f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
new file mode 100644
index 00000000..aede6e35
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
new file mode 100644
index 00000000..1a1a92c7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
new file mode 100644
index 00000000..ad667473
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
new file mode 100644
index 00000000..c499f455
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
new file mode 100644
index 00000000..8286ebf3
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
new file mode 100644
index 00000000..45878688
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
new file mode 100644
index 00000000..d89103ce
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
new file mode 100644
index 00000000..bb75fd42
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
new file mode 100644
index 00000000..b1629817
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
new file mode 100644
index 00000000..d8657604
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
new file mode 100644
index 00000000..2e5bd2f1
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
new file mode 100644
index 00000000..be5f302d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
new file mode 100644
index 00000000..8dd91cd7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
new file mode 100644
index 00000000..4cb79150
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
new file mode 100644
index 00000000..09dea426
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
new file mode 100644
index 00000000..0fbb6076
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
new file mode 100644
index 00000000..2aeab83b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
new file mode 100644
index 00000000..599415b4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
new file mode 100644
index 00000000..e4f8e308
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
new file mode 100644
index 00000000..34d16652
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
new file mode 100644
index 00000000..4bebef45
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
new file mode 100644
index 00000000..326468da
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
new file mode 100644
index 00000000..511b58f4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
new file mode 100644
index 00000000..d9906d14
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
new file mode 100644
index 00000000..f61c183a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
new file mode 100644
index 00000000..c10450fd
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
new file mode 100644
index 00000000..2d5cb195
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
new file mode 100644
index 00000000..b384f34d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
new file mode 100644
index 00000000..446e293b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
new file mode 100644
index 00000000..6f430298
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
new file mode 100644
index 00000000..1cd8ba88
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
new file mode 100644
index 00000000..1ee2eab6
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
new file mode 100644
index 00000000..2bc77816
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
new file mode 100644
index 00000000..d55ced08
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
new file mode 100644
index 00000000..8361e99c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
new file mode 100644
index 00000000..7507a67c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
new file mode 100644
index 00000000..61f050b2
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
new file mode 100644
index 00000000..d4a49d9c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
new file mode 100644
index 00000000..d1462789
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
new file mode 100644
index 00000000..e73f917a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
new file mode 100644
index 00000000..d40825df
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
new file mode 100644
index 00000000..b5c6869f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
new file mode 100644
index 00000000..4e21b0cc
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
new file mode 100644
index 00000000..2eac321b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
new file mode 100644
index 00000000..f7d2c3b4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
new file mode 100644
index 00000000..a013f400
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
new file mode 100644
index 00000000..2d94e65c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, float);
+DECL_FATTN_WMMA_F16_CASE(80, 16, float);
+DECL_FATTN_WMMA_F16_CASE(96, 16, float);
+DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+DECL_FATTN_WMMA_F16_CASE(256, 16, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
new file mode 100644
index 00000000..c3d9df3c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
@@ -0,0 +1,9 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, float);
+DECL_FATTN_WMMA_F16_CASE(80, 32, float);
+DECL_FATTN_WMMA_F16_CASE(96, 32, float);
+DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+DECL_FATTN_WMMA_F16_CASE(128, 32, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
new file mode 100644
index 00000000..bb680e40
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, half);
+DECL_FATTN_WMMA_F16_CASE(80, 16, half);
+DECL_FATTN_WMMA_F16_CASE(96, 16, half);
+DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+DECL_FATTN_WMMA_F16_CASE(256, 16, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
new file mode 100644
index 00000000..073f71b1
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, half);
+DECL_FATTN_WMMA_F16_CASE(80, 32, half);
+DECL_FATTN_WMMA_F16_CASE(96, 32, half);
+DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+DECL_FATTN_WMMA_F16_CASE(256, 32, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
new file mode 100644
index 00000000..d30710c5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
@@ -0,0 +1,8 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+DECL_FATTN_WMMA_F16_CASE(256, 8, half);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
new file mode 100755
index 00000000..d7874e6e
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+from glob import glob
+import os
+
+TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"]
+
+SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f{vkq_size}.cuh"
+
+DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
+"""
+
+SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+"""
+
+SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
+
+TYPES_MMQ = [
+ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
+ "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+ "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
+]
+
+SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE({type});
+"""
+
+
+def get_short_name(long_quant_name):
+ return long_quant_name.replace("GGML_TYPE_", "").lower()
+
+
+def get_head_sizes(type_k, type_v):
+ if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
+ return [64, 128, 256]
+ if type_k == "GGML_TYPE_F16":
+ return [64, 128]
+ return [128]
+
+
+for filename in glob("*.cu"):
+ os.remove(filename)
+
+for vkq_size in [16, 32]:
+ for type_k in TYPES_KV:
+ for type_v in TYPES_KV:
+ for head_size in get_head_sizes(type_k, type_v):
+ with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
+ f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
+
+for kq_acc_t in ["half", "float"]:
+ for cols_per_block in [8, 16, 32]:
+ if kq_acc_t == "float" and cols_per_block == 8:
+ continue
+
+ with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
+ f.write(SOURCE_FATTN_WMMA_START)
+
+ for head_size in [64, 80, 96, 112, 128, 256]:
+ if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
+ continue
+ if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
+ continue
+ f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
+
+for type in TYPES_MMQ:
+ with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
+ f.write(SOURCE_MMQ.format(type=type))
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
new file mode 100644
index 00000000..84ec8502
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
new file mode 100644
index 00000000..583c4e5a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
new file mode 100644
index 00000000..edaf1560
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
new file mode 100644
index 00000000..233d9342
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
new file mode 100644
index 00000000..6092dc71
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
new file mode 100644
index 00000000..1d5bd201
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
new file mode 100644
index 00000000..eb02fab0
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
new file mode 100644
index 00000000..1eb3b743
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu
new file mode 100644
index 00000000..6415369d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q2_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu
new file mode 100644
index 00000000..ffb6213a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q3_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu
new file mode 100644
index 00000000..0c0b0c8a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu
new file mode 100644
index 00000000..ee67f694
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu
new file mode 100644
index 00000000..9eeb3cd7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu
new file mode 100644
index 00000000..cc57fb97
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu
new file mode 100644
index 00000000..721ac790
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu
new file mode 100644
index 00000000..a2e90ffd
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu
new file mode 100644
index 00000000..470938fe
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q6_K);
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu
new file mode 100644
index 00000000..974477bb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/tsembd.cu b/ggml/src/ggml-cuda/tsembd.cu
new file mode 100644
index 00000000..153ddbcd
--- /dev/null
+++ b/ggml/src/ggml-cuda/tsembd.cu
@@ -0,0 +1,47 @@
+#include "tsembd.cuh"
+
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIdx.y: idx of timesteps->ne[0]
+    // blockIdx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+ int i = blockIdx.y;
+ int j = threadIdx.x + blockIdx.x * blockDim.x;
+ float * embed_data = (float *)((char *)dst + i*nb1);
+
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+ embed_data[dim] = 0.f;
+ }
+
+ int half = dim / 2;
+ if (j >= half) {
+ return;
+ }
+
+ float timestep = timesteps[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+ const int dim, const int max_period, cudaStream_t stream) {
+ int half_ceil = (dim + 1) / 2;
+ int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne00, 1);
+ timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
+}
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
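The kernel above computes the standard sinusoidal timestep embedding: for j in [0, dim/2) it writes cos(t * freq) and sin(t * freq) with freq = exp(-ln(max_period) * j / (dim/2)), and zero-pads a trailing element when dim is odd. A minimal host-side C++ sketch of the same math for an even dim, for reference only (timestep_embedding_ref is an illustrative name, not part of this diff):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference for timestep_embedding_f32: one embedding row for a single timestep t (even dim).
static std::vector<float> timestep_embedding_ref(float t, int dim, int max_period) {
    std::vector<float> out(dim, 0.0f);
    const int half = dim / 2;
    for (int j = 0; j < half; ++j) {
        const float freq = std::exp(-std::log((float) max_period) * j / half);
        out[j]        = std::cos(t * freq);
        out[j + half] = std::sin(t * freq);
    }
    return out;
}

int main() {
    for (float v : timestep_embedding_ref(12.0f, 8, 10000)) {
        std::printf("% .6f ", v);
    }
    std::printf("\n");
    return 0;
}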
diff --git a/ggml/src/ggml-cuda/tsembd.cuh b/ggml/src/ggml-cuda/tsembd.cuh
new file mode 100644
index 00000000..84340e3d
--- /dev/null
+++ b/ggml/src/ggml-cuda/tsembd.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
new file mode 100644
index 00000000..f9e20801
--- /dev/null
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -0,0 +1,314 @@
+#include "unary.cuh"
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+ const float GELU_COEF_A = 0.044715f;
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ float xi = x[i];
+ dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
+ const float GELU_QUICK_COEF = -1.702f;
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = tanhf(x[i]);
+}
+
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * x[i];
+}
+
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sqrtf(x[i]);
+}
+
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+ gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+ gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+ silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+ tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+ sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+ hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+ hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+ sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+ sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
+}
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
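All of the unary kernels above share the same structure: one thread per element, an early return on the bounds check, then a single elementwise formula. For reference, a minimal host-side C++ sketch of two of those formulas, the tanh-approximation GELU and SiLU, using the same constants as the kernels (gelu_ref and silu_ref are illustrative names, not part of this diff):

#include <cmath>
#include <cstdio>

// Reference for gelu_f32: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))).
static float gelu_ref(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + std::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

// Reference for silu_f32: x * sigmoid(x).
static float silu_ref(float x) {
    return x / (1.0f + std::exp(-x));
}

int main() {
    const float xs[] = {-2.0f, -1.0f, 0.0f, 1.0f, 2.0f};
    for (float x : xs) {
        std::printf("x=% .1f  gelu=% .6f  silu=% .6f\n", x, gelu_ref(x), silu_ref(x));
    }
    return 0;
}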
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
new file mode 100644
index 00000000..4cfb0479
--- /dev/null
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -0,0 +1,33 @@
+#include "common.cuh"
+
+#define CUDA_GELU_BLOCK_SIZE 256
+#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu
new file mode 100644
index 00000000..cf513c3a
--- /dev/null
+++ b/ggml/src/ggml-cuda/upscale.cu
@@ -0,0 +1,51 @@
+#include "upscale.cuh"
+
+static __global__ void upscale_f32(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ const float sf0, const float sf1, const float sf2, const float sf3) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index >= ne10 * ne11 * ne12 * ne13) {
+ return;
+ }
+
+ int i10 = index % ne10;
+ int i11 = (index / ne10) % ne11;
+ int i12 = (index / (ne10 * ne11)) % ne12;
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+ int i00 = i10 / sf0;
+ int i01 = i11 / sf1;
+ int i02 = i12 / sf2;
+ int i03 = i13 / sf3;
+
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ cudaStream_t stream) {
+ int dst_size = ne10 * ne11 * ne12 * ne13;
+ int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
+}
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
+ upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+}
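upscale_f32 above performs nearest-neighbour upscaling: the flat destination index is unflattened into (i10, i11, i12, i13) and each coordinate is divided by its scale factor sf = dst->ne[i] / src0->ne[i] to select the source element. A minimal 2D host-side C++ sketch of the same indexing, for reference only (upscale_ref and its contiguous row-major layout are illustrative assumptions, not part of this diff; the kernel does the same over 4 dimensions using the source tensor's byte strides):

#include <cstdio>
#include <vector>

// 2D reference for upscale_f32: dst(i10, i11) = src(i10/sf0, i11/sf1), nearest neighbour.
static std::vector<float> upscale_ref(const std::vector<float> & src,
                                      int ne00, int ne01, int ne10, int ne11) {
    const float sf0 = (float) ne10 / ne00;
    const float sf1 = (float) ne11 / ne01;
    std::vector<float> dst(ne10 * ne11);
    for (int i11 = 0; i11 < ne11; ++i11) {
        for (int i10 = 0; i10 < ne10; ++i10) {
            const int i00 = (int) (i10 / sf0);   // truncation matches the kernel's int = int / float
            const int i01 = (int) (i11 / sf1);
            dst[i11*ne10 + i10] = src[i01*ne00 + i00];
        }
    }
    return dst;
}

int main() {
    const std::vector<float> src = {1, 2, 3, 4};                  // 2x2 source
    const std::vector<float> dst = upscale_ref(src, 2, 2, 4, 4);  // 4x4 output, 2x in each direction
    for (int i11 = 0; i11 < 4; ++i11) {
        for (int i10 = 0; i10 < 4; ++i10) {
            std::printf("%.0f ", dst[i11*4 + i10]);               // each source value appears as a 2x2 block
        }
        std::printf("\n");
    }
    return 0;
}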
diff --git a/ggml/src/ggml-cuda/upscale.cuh b/ggml/src/ggml-cuda/upscale.cuh
new file mode 100644
index 00000000..d4d76523
--- /dev/null
+++ b/ggml/src/ggml-cuda/upscale.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
new file mode 100644
index 00000000..1248eacd
--- /dev/null
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -0,0 +1,1229 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "common.cuh"
+#include <cstdint>
+
+static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+ int x32 = x16[2*i32 + 0] << 0;
+ x32 |= x16[2*i32 + 1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
+ return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+ const int * v, const int * u, const float & d4, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 8 from each quant value
+ return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+ const float d4d8 = tmp.x;
+ const float m4s8 = tmp.y;
+#else
+ const float2 dm4f = __half22float2(dm4);
+ const float2 ds8f = __half22float2(ds8);
+ const float d4d8 = dm4f.x * ds8f.x;
+ const float m4s8 = dm4f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+ return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 16 from each quant value
+ return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+ const float d5d8 = tmp.x;
+ const float m5s8 = tmp.y;
+#else
+ const float2 dm5f = __half22float2(dm5);
+ const float2 ds8f = __half22float2(ds8);
+ const float d5d8 = dm5f.x * ds8f.x;
+ const float m5s8 = dm5f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+ return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
+ const int * v, const int * u, const T & d8_0, const T & d8_1) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+ return d8_0*d8_1 * ((T) sumi);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+#ifdef GGML_CUDA_F16
+ const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+ const float d8d8 = tmp.x;
+ const float m8s8 = tmp.y;
+#else
+ const float2 dm8f = __half22float2(dm8);
+ const float2 ds8f = __half22float2(ds8);
+ const float d8d8 = dm8f.x * ds8f.x;
+ const float m8s8 = dm8f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+ // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+ return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
+ const int * v, const int * u, const float * d8_0, const float & d8_1) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_0/2; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+ sumf += d8_0[i0/(QI8_0/2)]*sumi;
+ }
+
+ return d8_1*sumf;
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ 4
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+ const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const half2 & dm2, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++i) {
+ const int sc = scales[2*i];
+
+ const int vi = (v >> (2*i)) & 0x03030303;
+
+ sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+ sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+ }
+
+ const float2 dm2f = __half22float2(dm2);
+
+ return dm2f.x*sumf_d - dm2f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+template <int ns8>
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
+
+ float sumf = 0.0f;
+ float sumf_d8 = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+ const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+ int sumi_d0 = 0;
+
+ const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+ int sumi_d1 = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+ }
+ sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+ sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
+ }
+ sumf_d8 += dm2f1.x * sumi_d1;
+
+ if (i0/QI8_1 < ns8) {
+ const float2 s8f = __half22float2(s8[i0/QI8_1]);
+ sumf -= dm2f0.y*s8f.x;
+ sumf -= dm2f1.y*s8f.y;
+ } else {
+ int sumi_m0 = 0;
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+ }
+ sumf_d8 -= dm2f0.y * sumi_m0;
+
+ int sumi_m1 = 0;
+#pragma unroll
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+ sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+ }
+ sumf_d8 -= dm2f1.y * sumi_m1;
+ }
+ }
+
+ return sumf + d8*sumf_d8;
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ 2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ const int isc = scale_offset + 2*i;
+
+ const int isc_low = isc % (QK_K/32);
+ const int sc_shift_low = 4 * (isc / (QK_K/32));
+ const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+ const int isc_high = isc % (QK_K/64);
+ const int sc_shift_high = 2 * (isc / (QK_K/64));
+ const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+ const int sc = (sc_low | sc_high) - 32;
+
+ const int vil = (vl >> (2*i)) & 0x03030303;
+
+ const int vih = ((vh >> i) << 2) & 0x04040404;
+
+ const int vi = __vsubss4(vil, vih);
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d3 * sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d3, const float & d8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+ int sumi_sc = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+ }
+
+ sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+ }
+
+ return d3*d8 * sumi;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K; ++i) {
+ const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+ const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+ sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+ const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+ const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+ const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+ const int v0i = vl0i | vh0i;
+ const int v1i = vl1i | vh1i;
+
+ const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]);
+
+ }
+
+ const float2 dm5f = __half22float2(dm5);
+
+ return dm5f.x*sumf_d - dm5f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q5_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ const int sc = scales[4*i];
+
+ const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+ const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d*sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+ const float & d6, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+
+ const int sc_packed = get_int_b4(sc, 0);
+ const int8_t * sc_reg = (const int8_t *) &sc_packed;
+
+#pragma unroll
+ for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+ int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+ for (int i = i0; i < i0 + 2; ++i) {
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+ }
+
+ sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
+ }
+
+ return d6 * sumf_d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
+
+ int v[VDR_Q4_0_Q8_1_MMVQ];
+ int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b2(bq4_0->qs, iqs + i);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);
+ }
+
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
+
+ int v[VDR_Q4_1_Q8_1_MMVQ];
+ int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b4(bq4_1->qs, iqs + i);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);
+ }
+
+ return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
+
+ int vl[VDR_Q5_0_Q8_1_MMVQ];
+ int vh[VDR_Q5_0_Q8_1_MMVQ];
+ int u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b2(bq5_0->qs, iqs + i);
+ vh[i] = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);
+ }
+
+ return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
+
+ int vl[VDR_Q5_1_Q8_1_MMVQ];
+ int vh[VDR_Q5_1_Q8_1_MMVQ];
+ int u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b4(bq5_1->qs, iqs + i);
+ vh[i] = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);
+ }
+
+ return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
+
+ int v[VDR_Q8_0_Q8_1_MMVQ];
+ int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b2(bq8_0->qs, iqs + i);
+ u[i] = get_int_b4(bq8_1->qs, iqs + i);
+ }
+
+ return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const uint8_t * scales = bq2_K->scales + scale_offset;
+
+ const int v = get_int_b4(bq2_K->qs, iqs);
+ int u[QR2_K];
+ float d8[QR2_K];
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
+
+ const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const float d = bq3_K->d;
+
+ const int vl = get_int_b2(bq3_K->qs, iqs);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
+
+ // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+ // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+ // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+ // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+ // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+ const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ v[0] = q4[0];
+ v[1] = q4[4];
+
+ const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int i = 0; i < QR4_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
+
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
+
+ const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+ const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+ vl[0] = ql[0];
+ vl[1] = ql[4];
+
+ vh[0] = qh[0] >> bq8_offset;
+ vh[1] = qh[4] >> bq8_offset;
+
+ const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
+
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+ const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+ const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+ const int vl = get_int_b2(bq6_K->ql, iqs);
+ const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+ const int8_t * scales = bq6_K->scales + scale_offset;
+
+ int u[QR6_K];
+ float d8[QR6_K];
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
+ }
+
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+#define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
+
+ const int q2 = get_int_b2(bq2->qs, iqs);
+ const uint8_t * aux8 = (const uint8_t *) &q2;
+ const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);
+
+ int sumi = 0;
+#pragma unroll
+ for (int k0 = 0; k0 < 8; k0 += 2) {
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
+ sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
+ sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+ }
+
+ const int ls = aux32 >> 28;
+ sumi = (ls*sumi + sumi/2)/4;
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
+
+ const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
+ const int ls0 = bq2->scales[iqs/2] & 0x0F;
+ const int ls1 = bq2->scales[iqs/2] >> 4;
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ if (l0 < 4) {
+ sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+ sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+ } else {
+ sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+ sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+ }
+ }
+ const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
+
+ const int qs_packed = get_int_b2(bq2->qs, iqs/2);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq2->qh[iqs/2];
+
+ const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+ const int ls0 = bq2->scales[iqs/2] & 0x0F;
+ const int ls1 = bq2->scales[iqs/2] >> 4;
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300)));
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ if (l0 < 4) {
+ sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+ sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+ } else {
+ sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+ sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+ }
+ }
+ const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;
+
+ const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
+ const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+ }
+
+ const int ls = aux32 >> 28;
+ sumi = (ls*sumi + sumi/2)/2;
+ const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ 2
+
+// TODO: don't use lookup table for signs
+static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;
+
+ const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq3->qh[iqs/2];
+
+ const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int2 grid_pos = make_int2(
+ iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],
+ iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+ }
+
+ sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);
+
+ const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ 1
+
+static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+ const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
+
+ const int qs_packed = get_int_b2(bq1->qs, iqs);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq1->qh[iqs];
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+ const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+ }
+
+ const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+ const float2 ds = __half22float2(bq8_1[iqs].ds);
+ return d1q * (ds.x*sumi + ds.y*delta);
+}
+
+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ 1
+
+static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
+
+ const int qs_packed = get_int_b4(bq1->qs, iqs);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ int sumi[2] = {0};
+ float sumf[2] = {0.0f};
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));
+
+ const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+ const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+ sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
+ sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);
+
+ const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
+ int sumy = 0;
+ sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
+ sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
+ sumf[l0/4] += delta*sumy;
+ }
+
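+    // the block scale is an fp16 value scattered across the top 4 bits of the four 16-bit
+    // scale words; reassemble it here before applying the per-group 3-bit scales below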
+ const uint16_t * sc = (const uint16_t *) bq1->scales;
+
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
+ const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
+
+ const int tmp = sc[iqs/2] >> (6*(iqs%2));
+ const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
+ const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
+ return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
+}
+
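+// vec_dot for the ternary iq1_bn quantization: each byte of ql packs five base-3 digits
+// (3^5 = 243 fits in a byte) and k_mult[] = {3^4 .. 3^0} peels them off one at a time;
+// (3*v) >> 8 recovers the current digit as 0/1/2, which becomes a weight of -1/0/+1
+// after the trailing "- 1".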
+static __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+ const block_iq1_bn * bq1 = (const block_iq1_bn *) vbq + kbx;
+
+ static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ // iqs is 0 or 1
+
+ int sumi = 0;
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ const int * q8 = (const int *)bq8_1[iqs].qs;
+ int val[4];
+ for (int l = 0; l < 2; ++l) {
+ int8_t * a = (int8_t *)val;
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ }
+ }
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ *a++ = vs-1;
+ sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
+ }
+#else
+ const int8_t * q8 = bq8_1[iqs].qs;
+ for (int l = 0; l < 2; ++l) {
+ const int i16 = 2*iqs + l;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = bq1->ql[3*i16+k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[j]*(vs - 1);
+ }
+ q8 += 5;
+ }
+ uint8_t v = k_mult[i16]*bq1->extra;
+ int8_t vs = (v + (v >> 1)) >> 7;
+ sumi += q8[0]*(vs - 1);
+ q8++;
+ }
+#endif
+ return __low2float(bq8_1[iqs].ds) * sumi;
+}
+
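+// vec_dot for iq2_bn: in the dp4a path, vl & 0x03030303 selects the lowest 2-bit plane of each
+// byte and vl & 0x0c0c0c0c the next one (its implicit factor of 4 is undone by the 0.25f
+// weights), while vh = vl >> 4 provides the two upper planes, dotted against the second q8 block.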
+static __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+ const block_iq2_bn * bq2 = (const block_iq2_bn *) vbq + kbx;
+
+ // iqs is 0 or 1
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+ auto qs = (const uint16_t *)bq2->qs + 4*iqs;
+ auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
+ auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ for (int j = 0; j < 2; ++j) {
+ int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
+ int vh = vl >> 4;
+ sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
+ sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
+ sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
+ sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
+ }
+ auto d8l = __half22float2(bq8_1[0].ds);
+ auto d8h = __half22float2(bq8_1[1].ds);
+ return d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
+#else
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+ auto q8l = bq8_1[0].qs + 8*iqs;
+ auto q8h = bq8_1[1].qs + 8*iqs;
+ auto qs = bq2->qs + 8*iqs;
+ for (int j = 0; j < 8; ++j) {
+ sumi1 += q8l[j+ 0] * (qs[j] & 0x03);
+ sumi2 += q8l[j+16] * (qs[j] & 0x0c);
+ sumi3 += q8h[j+ 0] * (qs[j] & 0x30);
+ sumi4 += q8h[j+16] * (qs[j] & 0xc0);
+ }
+ auto d8l = __half22float2(bq8_1[0].ds);
+ auto d8h = __half22float2(bq8_1[1].ds);
+ return d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y;
+#endif
+}
+
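+// expands the eight 4-bit indices of a packed q4 word through the non-linear iq4_nl codebook
+// (kvalues_iq4nl); .x holds the four values from the low nibbles, .y those from the high nibbles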
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
+ const char4 val0_8 = make_char4(
+ kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
+ const char4 val1_8 = make_char4(
+ kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
+
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
+#define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4);
+
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+ }
+
+ const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
+ return d * sumi;
+}
+
+#define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
+
+ int sumi = 0;
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
+ const int2 v = get_int_from_table_16(aux_q4);
+
+ const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
+
+ sumi = ggml_cuda_dp4a(v.x, u0, sumi);
+ sumi = ggml_cuda_dp4a(v.y, u1, sumi);
+ }
+
+ const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
+ sumi *= ls - 32;
+
+ const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
+ return d * sumi;
+}
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
new file mode 100644
index 00000000..a2c8dbec
--- /dev/null
+++ b/ggml/src/ggml-impl.h
@@ -0,0 +1,655 @@
+#pragma once
+
+#include "ggml.h"
+
+// GGML internal header
+
+#include <assert.h>
+#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
+#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h> // fabsf
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#if defined(_MSC_VER)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ * ┌sign
+ * │
+ * │ ┌exponent
+ * │ │
+ * │ │ ┌mantissa
+ * │ │ │
+ * │┌──┴───┐┌─┴───┐
+ * 0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ * ┌sign
+ * │
+ * │ ┌exponent
+ * │ │
+ * │ │ ┌mantissa
+ * │ │ │
+ * │┌──┴───┐┌─┴───────────────────┐
+ * 0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ * ┌sign
+ * │
+ * │ ┌exponent
+ * │ │
+ * │ │ ┌mantissa
+ * │ │ │
+ * │┌─┴─┐┌─┴──────┐
+ * 0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+ union {
+ float f;
+ uint32_t i;
+ } u;
+ u.i = (uint32_t)h.bits << 16;
+ return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+ ggml_bf16_t h;
+ union {
+ float f;
+ uint32_t i;
+ } u;
+ u.f = s;
+ if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+ h.bits = (u.i >> 16) | 64; /* force to quiet */
+ return h;
+ }
+ if (!(u.i & 0x7f800000)) { /* subnormal */
+ h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+ return h;
+ }
+ h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+ return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
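+
+// Illustrative round trip (assumption: example values only, not part of the original code).
+// bf16 keeps only 8 significant bits (7 stored mantissa bits), so the conversion rounds to the
+// nearest representable value:
+//
+//   ggml_bf16_t b = GGML_FP32_TO_BF16(3.1415926f);
+//   float       r = GGML_BF16_TO_FP32(b);   // r == 3.140625f, the closest bf16 value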
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// static_assert is normally provided as a macro by <assert.h>; if it is not,
+// fall back to the C11 _Static_assert keyword.
+// Under plain C99, static_assert becomes a no-op.
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef __cplusplus
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#ifdef _MSC_VER
+
+typedef uint16_t ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
+
+#else
+
+typedef __fp16 ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
+
+#endif // _MSC_VER
+
+#if !defined(__aarch64__)
+
+// 32-bit ARM compatibility
+
+// vaddvq_s16
+// vpaddq_s16
+// vpaddq_s32
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+ int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+ int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+ return vcombine_s16(a0, b0);
+}
+
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+ int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+ int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+ return vcombine_s32(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+ int32x4_t res;
+
+ res[0] = roundf(vgetq_lane_f32(v, 0));
+ res[1] = roundf(vgetq_lane_f32(v, 1));
+ res[2] = roundf(vgetq_lane_f32(v, 2));
+ res[3] = roundf(vgetq_lane_f32(v, 3));
+
+ return res;
+}
+
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+ uint8x8_t res;
+
+ res[0] = a[0]; res[1] = b[0];
+ res[2] = a[1]; res[3] = b[1];
+ res[4] = a[2]; res[5] = b[2];
+ res[6] = a[3]; res[7] = b[3];
+
+ return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+ uint8x8_t res;
+
+ res[0] = a[4]; res[1] = b[4];
+ res[2] = a[5]; res[3] = b[5];
+ res[4] = a[6]; res[5] = b[6];
+ res[6] = a[7]; res[7] = b[7];
+
+ return res;
+}
+
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+ int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+ ggml_int16x8x2_t res;
+
+ res.val[0] = vld1q_s16(ptr + 0);
+ res.val[1] = vld1q_s16(ptr + 8);
+
+ return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+ uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+ ggml_uint8x16x2_t res;
+
+ res.val[0] = vld1q_u8(ptr + 0);
+ res.val[1] = vld1q_u8(ptr + 16);
+
+ return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+ uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+ ggml_uint8x16x4_t res;
+
+ res.val[0] = vld1q_u8(ptr + 0);
+ res.val[1] = vld1q_u8(ptr + 16);
+ res.val[2] = vld1q_u8(ptr + 32);
+ res.val[3] = vld1q_u8(ptr + 48);
+
+ return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+ int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+ ggml_int8x16x2_t res;
+
+ res.val[0] = vld1q_s8(ptr + 0);
+ res.val[1] = vld1q_s8(ptr + 16);
+
+ return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+ int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+ ggml_int8x16x4_t res;
+
+ res.val[0] = vld1q_s8(ptr + 0);
+ res.val[1] = vld1q_s8(ptr + 16);
+ res.val[2] = vld1q_s8(ptr + 32);
+ res.val[3] = vld1q_s8(ptr + 48);
+
+ return res;
+}
+
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+ int8x16_t res;
+
+ res[ 0] = a[b[ 0]];
+ res[ 1] = a[b[ 1]];
+ res[ 2] = a[b[ 2]];
+ res[ 3] = a[b[ 3]];
+ res[ 4] = a[b[ 4]];
+ res[ 5] = a[b[ 5]];
+ res[ 6] = a[b[ 6]];
+ res[ 7] = a[b[ 7]];
+ res[ 8] = a[b[ 8]];
+ res[ 9] = a[b[ 9]];
+ res[10] = a[b[10]];
+ res[11] = a[b[11]];
+ res[12] = a[b[12]];
+ res[13] = a[b[13]];
+ res[14] = a[b[14]];
+ res[15] = a[b[15]];
+
+ return res;
+}
+
+// NOTE: not tested
+inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+ uint8x16_t res;
+
+ res[ 0] = a[b[ 0]];
+ res[ 1] = a[b[ 1]];
+ res[ 2] = a[b[ 2]];
+ res[ 3] = a[b[ 3]];
+ res[ 4] = a[b[ 4]];
+ res[ 5] = a[b[ 5]];
+ res[ 6] = a[b[ 6]];
+ res[ 7] = a[b[ 7]];
+ res[ 8] = a[b[ 8]];
+ res[ 9] = a[b[ 9]];
+ res[10] = a[b[10]];
+ res[11] = a[b[11]];
+ res[12] = a[b[12]];
+ res[13] = a[b[13]];
+ res[14] = a[b[14]];
+ res[15] = a[b[15]];
+
+ return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t int8x16x2_t
+#define ggml_int8x16x4_t int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2 vld1q_u8_x2
+#define ggml_vld1q_u8_x4 vld1q_u8_x4
+#define ggml_vld1q_s8_x2 vld1q_s8_x2
+#define ggml_vld1q_s8_x4 vld1q_s8_x4
+#define ggml_vqtbl1q_s8 vqtbl1q_s8
+#define ggml_vqtbl1q_u8 vqtbl1q_u8
+
+#endif // !defined(__aarch64__)
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+ const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+ const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+ return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+#endif // !defined(__ARM_FEATURE_DOTPROD)
+
+#endif // defined(__ARM_NEON)
+
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+ ggml_fp16_internal_t tmp;
+ memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+ return (float)tmp;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+ ggml_fp16_t res;
+ ggml_fp16_internal_t tmp = f;
+ memcpy(&res, &tmp, sizeof(ggml_fp16_t));
+ return res;
+}
+
+#else
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#if defined(__loongarch64)
+#if defined(__loongarch_asx)
+#include <lasxintrin.h>
+#endif
+#if defined(__loongarch_sx)
+#include <lsxintrin.h>
+#endif
+#endif
+
+#if defined(__loongarch_asx)
+
+typedef union {
+ int32_t i;
+ float f;
+} ft_union;
+
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(float val) {
+ ft_union fi_tmpval = {.f = val};
+ return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+}
+
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+ ft_union fi_tmpval = {.f = val};
+ return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+}
+#endif
+
+#ifdef __F16C__
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+ register float f;
+ register double d;
+ __asm__(
+ "mtfprd %0,%2\n"
+ "xscvhpdp %0,%0\n"
+ "frsp %1,%0\n" :
+ /* temp */ "=d"(d),
+ /* out */ "=f"(f):
+ /* in */ "r"(h));
+ return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+ register double d;
+ register ggml_fp16_t r;
+ __asm__( /* xscvdphp can work on double or single precision */
+ "xscvdphp %0,%2\n"
+ "mffprd %1,%0\n" :
+ /* temp */ "=d"(d),
+ /* out */ "=r"(r):
+ /* in */ "f"(f));
+ return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32;
+ fp32.as_bits = w;
+ return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32;
+ fp32.as_value = f;
+ return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+ const uint32_t w = (uint32_t) h << 16;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ const uint32_t two_w = w + w;
+
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float exp_scale = 0x1.0p-112f;
+#else
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result = sign |
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
+ const float scale_to_inf = 0x1.0p+112f;
+ const float scale_to_zero = 0x1.0p-110f;
+#else
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
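+
+// Spot-check values for the bit manipulation above (assumption: informal examples, not part of
+// the original code):
+//   ggml_compute_fp16_to_fp32(0x3C00) == 1.0f      (exponent 15 with bias 15, mantissa 0)
+//   ggml_compute_fp16_to_fp32(0xC000) == -2.0f
+//   ggml_compute_fp32_to_fp16(65536.0f) == 0x7C00  (above the fp16 maximum 65504, saturates to +inf)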
+
+#endif // __F16C__
+
+#endif // defined(__ARM_NEON) && !defined(_MSC_VER)
+
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#endif // __ARM_FEATURE_SVE
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+extern float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON it is faster to convert fp16 <-> fp32 directly than to call into ggml_lookup_fp16_to_fp32,
+// so GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 are defined elsewhere for NEON.
+// The same is true for POWER9.
+#if !defined(GGML_FP16_TO_FP32)
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+ uint16_t s;
+ memcpy(&s, &f, sizeof(uint16_t));
+ return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#endif
+
+#if !defined(GGML_FP32_TO_FP16)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+#endif
+
+#define GGML_HASHTABLE_FULL ((size_t)-1)
+#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
+
+struct ggml_hash_set ggml_hash_set_new(size_t size);
+
+bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
+size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns the index; asserts if the table is full
+size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
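+
+// minimal usage sketch (assumption: illustrative only, not part of this header):
+//
+//   struct ggml_hash_set set = ggml_hash_set_new(graph_size);
+//   if (ggml_hash_insert(set, node) == GGML_HASHTABLE_ALREADY_EXISTS) {
+//       // the node was already visited
+//   }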
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp
new file mode 100644
index 00000000..ed5f2e34
--- /dev/null
+++ b/ggml/src/ggml-kompute.cpp
@@ -0,0 +1,2038 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-kompute.h"
+
+// These are generated at build time by a CMake custom command
+#include "shaderop_scale.h"
+#include "shaderop_scale_8.h"
+#include "shaderop_add.h"
+#include "shaderop_addrow.h"
+#include "shaderop_mul.h"
+#include "shaderop_silu.h"
+#include "shaderop_relu.h"
+#include "shaderop_gelu.h"
+#include "shaderop_softmax.h"
+#include "shaderop_norm.h"
+#include "shaderop_rmsnorm.h"
+#include "shaderop_diagmask.h"
+#include "shaderop_mul_mat_f16.h"
+#include "shaderop_mul_mat_q8_0.h"
+#include "shaderop_mul_mat_q4_0.h"
+#include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q6_k.h"
+#include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
+#include "shaderop_getrows_f16.h"
+#include "shaderop_getrows_q4_0.h"
+#include "shaderop_getrows_q4_1.h"
+#include "shaderop_getrows_q6_k.h"
+#include "shaderop_rope_f16.h"
+#include "shaderop_rope_f32.h"
+#include "shaderop_cpy_f16_f16.h"
+#include "shaderop_cpy_f16_f32.h"
+#include "shaderop_cpy_f32_f16.h"
+#include "shaderop_cpy_f32_f32.h"
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include <kompute/Kompute.hpp>
+#include <vulkan/vulkan.hpp>
+
+#ifdef __linux__
+#include <cstdlib> // for setenv
+#endif
+
+#define QK4_0 32
+#define QR4_0 2
+#define QK4_1 32
+#define QK_NL 16
+
+typedef ggml_fp16_t half;
+
+static std::string ggml_kompute_format_name(int device) {
+ return "Kompute" + std::to_string(device);
+}
+
+struct ggml_kompute_context {
+ int device;
+ std::string name;
+ std::shared_ptr<vk::DescriptorPool> pool;
+
+ ggml_kompute_context(int device)
+ : device(device), name(ggml_kompute_format_name(device)) {}
+};
+
+// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object,
+// merge the init functions, and simplify object lifetime management. As it currently stands,
+// the kompute manager is always required for device discovery, but the kompute context
+// is only created once a device is set and Vulkan is explicitly turned on.
+static ggml_kompute_context *s_kompute_context = nullptr;
+
+class kompute_manager {
+ kp::Manager *s_mgr = nullptr;
+
+public:
+ kp::Manager *operator()() {
+ if (s_mgr && !s_mgr->hasInstance()) {
+ destroy();
+ }
+ if (!s_mgr) {
+ s_mgr = new kp::Manager;
+ }
+ return s_mgr;
+ }
+
+ void destroy() {
+ delete s_mgr;
+ s_mgr = nullptr;
+ }
+};
+
+static kompute_manager komputeManager;
+
+struct ggml_vk_memory {
+ void *data = nullptr;
+ size_t size = 0;
+ vk::DeviceMemory *primaryMemory = nullptr;
+ vk::Buffer *primaryBuffer = nullptr;
+ vk::DeviceMemory *stagingMemory = nullptr;
+ vk::Buffer *stagingBuffer = nullptr;
+};
+
+#ifdef __linux__
+__attribute__((constructor))
+static void enable_sam() {
+ setenv("RADV_PERFTEST", "sam", false);
+}
+#endif
+
+static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
+ vk::PhysicalDeviceFeatures availableFeatures;
+ physical_device.getFeatures(&availableFeatures);
+
+ if (!availableFeatures.shaderInt16)
+ return false;
+
+ vk::PhysicalDeviceVulkan11Features availableFeatures11;
+ vk::PhysicalDeviceVulkan12Features availableFeatures12;
+
+ availableFeatures11.pNext = &availableFeatures12;
+ availableFeatures12.pNext = nullptr;
+
+ vk::PhysicalDeviceFeatures2 features2;
+ features2.pNext = &availableFeatures11;
+
+ physical_device.getFeatures2(&features2);
+
+ if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
+ !availableFeatures11.storageBuffer16BitAccess) {
+ return false;
+ }
+
+ if (!availableFeatures12.storageBuffer8BitAccess ||
+ !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
+ !availableFeatures12.shaderFloat16 ||
+ !availableFeatures12.shaderInt8) {
+ return false;
+ }
+
+ return true;
+}
+
+static const char * ggml_vk_getVendorName(uint32_t vendorID) {
+ switch (vendorID) {
+ case 0x10DE:
+ return "nvidia";
+ case 0x1002:
+ return "amd";
+ case 0x8086:
+ return "intel";
+ default:
+ return "unknown";
+ }
+}
+
+static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
+ std::vector<ggml_vk_device> results;
+ if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
+ return results;
+
+ std::vector<vk::PhysicalDevice> physical_devices;
+ try {
+ physical_devices = komputeManager()->listDevices();
+ } catch (vk::SystemError & err) {
+ std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
+ return results;
+ }
+
+ uint32_t deviceCount = physical_devices.size();
+ if (deviceCount == 0)
+ return results;
+
+ std::unordered_map<std::string, size_t> count_by_name;
+
+ for (uint32_t i = 0; i < deviceCount; i++) {
+ const auto & physical_device = physical_devices[i];
+
+ VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
+ VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
+ const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
+ const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
+ if (major < 1 || minor < 2)
+ continue;
+
+ if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
+ continue;
+
+ size_t heapSize = 0;
+ for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
+ VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
+ if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
+ heapSize = heap.size;
+ break;
+ }
+ }
+
+ if (heapSize < memoryRequired)
+ continue;
+
+ auto ext_props = physical_device.enumerateDeviceExtensionProperties();
+ bool has_maintenance4 = false;
+
+ // Check if maintenance4 is supported
+ for (const auto & properties : ext_props) {
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+ has_maintenance4 = true;
+ }
+ }
+
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
+ vk::PhysicalDeviceProperties2 dev_props2;
+ vk::PhysicalDeviceMaintenance3Properties dev_props3;
+ vk::PhysicalDeviceMaintenance4Properties dev_props4;
+ dev_props2.pNext = &dev_props3;
+ dev_props3.pNext = &subgroup_props;
+ if (has_maintenance4) {
+ subgroup_props.pNext = &dev_props4;
+ }
+ physical_device.getProperties2(&dev_props2);
+
+ if (subgroup_props.subgroupSize < 32)
+ continue;
+
+ ggml_vk_device d;
+ d.index = i;
+ d.type = dev_props.deviceType;
+ d.heapSize = heapSize;
+ d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
+ d.subgroupSize = subgroup_props.subgroupSize;
+ d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
+
+ if (has_maintenance4) {
+ d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
+ } else {
+ d.maxAlloc = dev_props3.maxMemoryAllocationSize;
+ }
+
+ std::string name(dev_props.deviceName);
+ size_t n_idx = ++count_by_name[name];
+ if (n_idx > 1) {
+ name += " (" + std::to_string(n_idx) + ")";
+ }
+ d.name = strdup(name.c_str());
+
+ results.push_back(d);
+ }
+
+ std::stable_sort(results.begin(), results.end(),
+ [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
+ if (lhs.type != rhs.type) {
+ if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
+ if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
+
+ if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
+ if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
+ }
+ return lhs.heapSize < rhs.heapSize;
+ }
+ );
+
+ return results;
+}
+
+// public API returns a C-style array
+ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
+ auto devices = ggml_vk_available_devices_internal(memoryRequired);
+ *count = devices.size();
+ if (devices.empty()) {
+ return nullptr;
+ }
+
+ size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
+ auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
+ memcpy(arr, devices.data(), nbytes);
+ return arr;
+}
+
+static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
+ devices.erase(
+ std::remove_if(devices.begin(), devices.end(),
+ [&targetVendor](const ggml_vk_device& device) {
+ return device.vendor != targetVendor;
+ }),
+ devices.end()
+ );
+}
+
+static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
+ devices.erase(
+ std::remove_if(devices.begin(), devices.end(),
+ [&targetName](const ggml_vk_device& device) {
+ return device.name != targetName;
+ }),
+ devices.end()
+ );
+}
+
+static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
+ if (name.empty())
+ return false;
+
+ auto devices = ggml_vk_available_devices_internal(memoryRequired);
+ if (name == "amd" || name == "nvidia" || name == "intel") {
+ ggml_vk_filterByVendor(devices, name);
+ } else if (name != "gpu") {
+ ggml_vk_filterByName(devices, name);
+ }
+
+ if (devices.empty())
+ return false;
+
+ *device = devices.front();
+ return true;
+}
+
+bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
+ return ggml_vk_get_device(device, memoryRequired, std::string(name));
+}
+
+bool ggml_vk_has_vulkan() {
+ return komputeManager()->hasVulkan();
+}
+
+bool ggml_vk_has_device() {
+ return komputeManager()->hasDevice();
+}
+
+ggml_vk_device ggml_vk_current_device() {
+ if (!komputeManager()->hasDevice())
+ return ggml_vk_device();
+
+ auto devices = ggml_vk_available_devices_internal(0);
+ ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
+ GGML_ASSERT(!devices.empty());
+ return devices.front();
+}
+
+static
+void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
+ std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
+ vk::DescriptorPoolSize(
+ vk::DescriptorType::eStorageBuffer,
+ 3 * size // Descriptor count is number of possible tensors to pass into an algorithm
+ )
+ };
+
+ vk::DescriptorPoolCreateInfo descriptorPoolInfo(
+ vk::DescriptorPoolCreateFlags(),
+ size, // Max sets
+ static_cast<uint32_t>(descriptorPoolSizes.size()),
+ descriptorPoolSizes.data());
+
+ ctx->pool = std::make_shared<vk::DescriptorPool>();
+ vk::Result r = komputeManager()->device()->createDescriptorPool(
+ &descriptorPoolInfo, nullptr, ctx->pool.get());
+ if (r != vk::Result::eSuccess)
+        std::cerr << "Error allocating descriptor pool: " << vk::to_string(r) << std::endl;
+}
+
+static
+void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
+ if (ctx->pool) {
+ komputeManager()->device()->destroy(
+ *ctx->pool,
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+ ctx->pool = nullptr;
+ }
+}
+
+static
+vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
+ vk::BufferCreateInfo bufferCreateInfo;
+ bufferCreateInfo.size = size;
+ bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
+ vk::BufferUsageFlagBits::eTransferSrc |
+ vk::BufferUsageFlagBits::eTransferDst;
+ bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
+
+ vk::Buffer *vkBuffer = new vk::Buffer;
+ vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
+ if (r != vk::Result::eSuccess)
+ std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
+ return vkBuffer;
+}
+
+static
+vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
+
+ uint32_t memoryTypeIndex = -1;
+ bool memoryTypeIndexFound = false;
+ vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
+ for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
+ const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
+ const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
+ if (memoryHeap.size < size) {
+ continue;
+ }
+
+ if (requirements.memoryTypeBits & (1 << i)) {
+ if (((memoryProperties.memoryTypes[i]).propertyFlags &
+ flags) == flags) {
+ memoryTypeIndex = i;
+ memoryTypeIndexFound = true;
+ if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
+ *isHostVisible = true;
+ }
+ break;
+ }
+ }
+ }
+ if (!memoryTypeIndexFound) {
+ throw std::runtime_error(
+ "Memory type index for buffer creation not found");
+ }
+
+ vk::MemoryAllocateInfo allocInfo;
+ allocInfo.allocationSize = size;
+ allocInfo.memoryTypeIndex = memoryTypeIndex;
+ vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
+ vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
+ if (r != vk::Result::eSuccess) {
+ std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
+ throw std::runtime_error("Error allocating vulkan memory.");
+ }
+ return vkDeviceMemory;
+}
+
+static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
+ size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
+
+ // If offset is already aligned, return it directly
+ if (offset % minStorageBufferOffsetAlignment == 0) {
+ return offset;
+ }
+
+ // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
+ return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
+}
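+
+// Example (assumption: illustrative numbers only): with a 256-byte minStorageBufferOffsetAlignment,
+// ggml_vk_aligned_offset(buffer, 1000) returns 768; the remaining 232 bytes are handled by the
+// caller as an additional offset into the aligned tensor (see ggml_vk_get_tensor below).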
+
+static ggml_vk_memory ggml_vk_allocate(size_t size) {
+ ggml_vk_memory memory;
+ bool isHostVisible = false;
+ {
+ memory.primaryBuffer = ggml_vk_allocate_buffer(size);
+ vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
+ vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
+ memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
+ komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
+ if (isHostVisible) {
+ vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
+ if (r != vk::Result::eSuccess)
+ std::cerr << "Error mapping memory" << vk::to_string(r);
+ }
+ }
+
+ if (!isHostVisible) {
+ memory.stagingBuffer = ggml_vk_allocate_buffer(size);
+ vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
+ vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
+ vk::MemoryPropertyFlagBits::eHostCoherent |
+ vk::MemoryPropertyFlagBits::eHostCached;
+ memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
+ komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
+ vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
+ if (r != vk::Result::eSuccess)
+ std::cerr << "Error mapping memory" << vk::to_string(r);
+ }
+
+ memory.size = size;
+ return memory;
+}
+
+static void ggml_vk_free_memory(ggml_vk_memory &memory)
+{
+ komputeManager()->device()->destroy(
+ *memory.primaryBuffer,
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+ if (memory.stagingBuffer) {
+ komputeManager()->device()->destroy(
+ *memory.stagingBuffer,
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+ }
+ komputeManager()->device()->freeMemory(
+ *memory.primaryMemory,
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+ if (memory.stagingMemory) {
+ komputeManager()->device()->freeMemory(
+ *memory.stagingMemory,
+ (vk::Optional<const vk::AllocationCallbacks>)nullptr);
+ }
+}
+
+static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
+
+static
+ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
+ ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+ // compatibility with ggml-backend
+ GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
+
+ ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
+
+ const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
+
+ GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
+
+ offset = uint64_t(ioffs);
+ return buf_ctx;
+}
+
+static
+const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
+ uint64_t originalOffset = 0;
+ auto * res = ggml_vk_find_tensor(t, originalOffset);
+ if (!res) {
+ static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+ return nullTensor;
+ }
+
+ // Create a tensor whose memory will be composed of our buffers at the correct offset
+ const size_t nelements = ggml_nelements(t);
+ size_t nbytes = ggml_nbytes(t);
+
+ size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
+ if (alignedOffset) {
+ *alignedOffset = originalOffset - vulkanOffset;
+ nbytes += *alignedOffset;
+ }
+
+ return komputeManager()->tensor(
+ t->data,
+ nelements,
+ nbytes, kp::Tensor::TensorDataTypes::eFloat,
+ res->primaryMemory, res->primaryBuffer,
+ res->stagingMemory, res->stagingBuffer,
+ vulkanOffset);
+}
+
+static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
+ if (size % sizeof(uint32_t) != 0) {
+ throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
+ }
+
+ const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
+ size_t count = size / sizeof(uint32_t);
+ return std::vector<uint32_t>(data_ptr, data_ptr + count);
+}
+
+inline static
+uint32_t safe_divide(uint32_t a, uint32_t b) {
+ if (b <= 1) {
+ return a;
+ }
+ if ((a % b) != 0) {
+ fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
+ GGML_ASSERT(!"safe_divide result would've had remainder");
+ }
+ return a / b;
+}
+
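+// Common pattern for the ops below: the SPIR-V shader is compiled once and the resulting
+// kp::Algorithm is cached in the kompute manager under a name (usually __func__); subsequent
+// calls only update the tensors, workgroup size and push constants before recording an
+// OpAlgoDispatch into the sequence.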
+static void ggml_vk_add(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+ int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
+ int32_t ne0,
+ int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
+ kp::shader_data::op_add_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00;
+ int32_t nb00, nb01, nb02, nb03;
+ int32_t ne10, ne11, ne12, ne13;
+ int32_t nb10, nb11, nb12, nb13;
+ int32_t ne0;
+ int32_t nb0, nb1, nb2, nb3;
+ } const pushConsts {
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_addrow(kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ uint32_t size, uint32_t row = 0) {
+
+ const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
+ kp::shader_data::op_addrow_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ uint32_t row;
+ } const pushConsts {
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ row
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__))
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
+ else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({size});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_mul(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+ int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
+ int32_t ne0,
+ int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
+ kp::shader_data::op_mul_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00;
+ int32_t nb00, nb01, nb02, nb03;
+ int32_t ne10, ne11, ne12, ne13;
+ int32_t nb10, nb11, nb12, nb13;
+ int32_t ne0;
+ int32_t nb0, nb1, nb2, nb3;
+ } const pushConsts {
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_scale(kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& in,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inOff, uint32_t outOff,
+ uint32_t size, float scale) {
+ const static auto spirv_1 = getSpirvShader(
+ kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
+ );
+ const static auto spirv_8 = getSpirvShader(
+ kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
+ );
+
+ struct PushConstants {
+ uint32_t inOff, outOff;
+ float scale;
+ } const pushConsts {
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
+ scale
+ };
+
+ const auto * spirv = &spirv_1;
+ std::string name(__func__);
+ if (size % 8 == 0) {
+ size /= 8;
+ name += "_8";
+ spirv = &spirv_8;
+ }
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({in, out});
+ s_algo->setWorkgroup({size});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_xxlu(
+ const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& in,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inOff, uint32_t outOff,
+ uint32_t size
+) {
+ struct PushConstants {
+ uint32_t inOff, outOff;
+ } const pushConsts {
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
+ };
+
+ auto name = std::string(__func__) + "_" + suffix;
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({in, out});
+ s_algo->setWorkgroup({size});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_silu(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
+ kp::shader_data::op_silu_comp_spv_len);
+
+ ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_relu(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
+ kp::shader_data::op_relu_comp_spv_len);
+
+ ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_gelu(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
+ kp::shader_data::op_gelu_comp_spv_len);
+
+ ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
+}
+
+static void ggml_vk_soft_max(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
+ float scale
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
+ kp::shader_data::op_softmax_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne01, ne02;
+ float scale;
+ int32_t mask;
+ } pushConsts {
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne01, ne02,
+ scale,
+ bool(inB)
+ };
+
+ auto & inB_ = inB ? inB : inA;
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+        // FIXME: The softmax kernel needs to be fixed to use the subgroup size, which can vary by device
+ const uint32_t local_x = 32;
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB_, out});
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+static void ggml_vk_norm_(
+ const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& in,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inOff, uint32_t outOff,
+ int32_t ne00, int32_t nb01,
+ int32_t nrows, float epsilon
+) {
+ GGML_ASSERT(nb01%sizeof(float) == 0);
+ GGML_ASSERT(ne00%sizeof(float) == 0);
+
+ struct PushConstants {
+ uint32_t inOff, outOff;
+ uint32_t ne00, nb01;
+ float eps;
+ } pushConsts {
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
+ (uint32_t)ne00, (uint32_t)nb01, epsilon
+ };
+
+ auto name = std::string(__func__) + "_" + suffix;
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({in, out});
+ s_algo->setWorkgroup({(uint32_t)nrows});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_norm(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
+ kp::shader_data::op_norm_comp_spv_len);
+
+ ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_rms_norm(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
+ kp::shader_data::op_rmsnorm_comp_spv_len);
+
+ ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
+}
+
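+// Causal mask (DIAG_MASK_INF): masks out positions beyond n_past, with one shader
+// invocation per element of the (ne00, ne01, ne02) grid.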
+static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& in,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inOff, uint32_t outOff,
+ uint32_t n_past,
+ int32_t ne00, int32_t ne01, int32_t ne02) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
+ kp::shader_data::op_diagmask_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inOff, outOff;
+ uint32_t n_past;
+ int32_t ne00, ne01;
+ } pushConsts {
+ safe_divide(inOff, 4), safe_divide(outOff, 4),
+ n_past,
+ ne00, ne01
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__))
+ s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
+ else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({in, out});
+ s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
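+// F16 x F32 matrix multiplication. inA holds f16 data, so its byte offset is
+// divided by 2 while the inB/out offsets are divided by 4; each workgroup in the
+// y dimension covers four src1 rows (ny = ceil(ne11/4)).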
+static void ggml_vk_mul_mat_f16(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02,
+ uint32_t nb00, uint32_t nb01, uint32_t nb02,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ uint32_t nb10, uint32_t nb11, uint32_t nb12,
+ int32_t ne0, int32_t ne1,
+ uint32_t r2, uint32_t r3
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
+ kp::shader_data::op_mul_mat_f16_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne01, ne02;
+ uint32_t nb00, nb01, nb02;
+ int32_t ne10, ne11, ne12;
+ uint32_t nb10, nb11, nb12;
+ int32_t ne0, ne1;
+ uint32_t r2, r3;
+ } pushConsts {
+ safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne01, ne02,
+ nb00, nb01, nb02,
+ ne10, ne11, ne12,
+ nb10, nb11, nb12,
+ ne0, ne1,
+ r2, r3
+ };
+
+ const unsigned ny = unsigned((ne11 + 4 - 1)/4);
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
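+// F32 x F32 matrix-matrix multiplication; the dispatch grid is
+// (ne01, ne11, max(ne12, ne02)) and the specialization constant is set to the
+// device's subgroup size.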
+static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02,
+ uint32_t nb01, uint32_t nb02,
+ int32_t ne11, int32_t ne12,
+ uint32_t nb11, uint32_t nb12,
+ uint32_t nb1, uint32_t nb2) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
+ kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne01, ne02, ne11, ne12;
+ uint32_t nb01, nb02;
+ uint32_t nb11, nb12;
+ uint32_t nb1, nb2;
+ } pushConsts {
+ safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne01, ne02, ne11, ne12,
+ nb01, nb02, nb11, nb12,
+ nb1, nb2
+ };
+
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize;
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(),
+ {inA, inB, out}, spirv,
+ {unsigned(ne01),
+ unsigned(ne11),
+ unsigned(std::max(ne12, ne02))
+ },
+ {local_x},
+ {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned(ne01),
+ unsigned(ne11),
+ unsigned(std::max(ne12, ne02)),
+ });
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
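+// Common dispatch path for the quantized mat-vec shaders (q4_0, q4_1, q8_0).
+// block_size is the divisor applied to inAOff; the quantized wrappers pass 1
+// because blocks are addressed unaligned. Eight src0 rows are handled per
+// workgroup ((ne01 + 7)/8 groups in x).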
+static void ggml_vk_mul_mat_impl(
+ const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02,
+ int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+ int32_t ne0, int32_t ne1,
+ uint32_t r2, uint32_t r3
+) {
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne01, ne02;
+ int32_t ne10, ne12;
+ int32_t ne0, ne1;
+ uint32_t r2, r3;
+ } pushConsts {
+ safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne01, ne02,
+ ne10, ne12,
+ ne0, ne1,
+ r2, r3
+ };
+
+ auto name = std::string(__func__) + "_" + suffix;
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q4_0(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
+ kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
+
+ ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q4_1(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
+ kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
+
+ ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_mul_mat_q8_0(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
+ kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
+
+ ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
+}
+
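+// Q6_K mat-vec: unlike the generic quantized path, inAOff is passed in bytes and
+// the GQA ratio ne12/ne02 is sent instead of the r2/r3 broadcast factors.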
+static void ggml_vk_mul_mat_q6_k(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
+ int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
+ kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne10, ne0, ne1, ne01, gqa;
+ } pushConsts {
+ inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, ne10, ne0, ne1, ne01, ne12/ne02
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
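+// Row gather (GET_ROWS): copies the rows of inA selected by the indices in inB
+// into out. element_size converts the inA byte offset; qk, when non-zero, is the
+// quantization block size and is only used to check that ne00 is block-aligned.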
+static void ggml_vk_get_rows(
+ const std::vector<uint32_t>& spirv,
+ const char * suffix,
+ unsigned element_size, unsigned qk,
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t nb01, int32_t nb1,
+ uint32_t size
+) {
+ GGML_ASSERT(nb01%element_size == 0);
+ GGML_ASSERT(nb1%sizeof(float) == 0);
+ if (qk) GGML_ASSERT(ne00%qk == 0);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, nb01, nb1;
+ } pushConsts {
+ safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
+ ne00, nb01, nb1
+ };
+
+ auto name = std::string(__func__) + "_" + suffix;
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({size});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+ kp::shader_data::op_getrows_f32_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_f16(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
+ kp::shader_data::op_getrows_f16_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q4_0(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
+ kp::shader_data::op_getrows_q4_0_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q4_1(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
+ kp::shader_data::op_getrows_q4_1_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_get_rows_q6_k(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
+ kp::shader_data::op_getrows_q6_k_comp_spv_len);
+ ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
+}
+
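+// RoPE for f32/f16 tensors. The shader variant is picked from src0t, and all
+// strides are checked to be multiples of the element size before being packed
+// into push constants.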
+static void ggml_vk_rope(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
+ float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+ int32_t ne01, int32_t ne02, int32_t ne03,
+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
+ int32_t ne0,
+ uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
+) {
+ GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
+
+ static const auto spirv_f16 = getSpirvShader(
+ kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
+ );
+ static const auto spirv_f32 = getSpirvShader(
+ kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
+ );
+
+ int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
+
+ GGML_ASSERT(nb03 % type_size == 0);
+ GGML_ASSERT(nb02 % type_size == 0);
+ GGML_ASSERT(nb01 % type_size == 0);
+ GGML_ASSERT(nb00 % type_size == 0);
+ GGML_ASSERT(nb3 % type_size == 0);
+ GGML_ASSERT(nb2 % type_size == 0);
+ GGML_ASSERT(nb1 % type_size == 0);
+ GGML_ASSERT(nb0 % type_size == 0);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t n_dims, mode, n_ctx_orig;
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ uint32_t nb00, nb01, nb02, nb03;
+ int32_t ne0;
+ uint32_t nb0, nb1, nb2, nb3;
+ } pushConsts {
+ safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+ n_dims, mode, n_ctx_orig,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+ nb00, nb01, nb02, nb03,
+ ne0,
+ nb0, nb1, nb2, nb3
+ };
+
+ auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name)) {
+ s_algo = komputeManager()->algorithm<float, PushConstants>(
+ name, s_kompute_context->pool.get(), {inA, inB, out},
+ src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+ {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
+ );
+ } else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
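+// Generic copy/convert dispatch shared by the CPY wrappers below; the compiled
+// algorithm is cached per (input element size, output element size) pair.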
+static void ggml_vk_cpy(
+ const std::vector<uint32_t>& spirv,
+ uint32_t in_element_size, uint32_t out_element_size,
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& in,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
+ uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
+ int32_t ne0, int32_t ne1, int32_t ne2,
+ uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
+) {
+ struct PushConstants {
+ uint32_t inOff, outOff;
+ int32_t ne00, ne01, ne02;
+ uint32_t nb00, nb01, nb02, nb03;
+ int32_t ne0, ne1, ne2;
+ uint32_t nb0, nb1, nb2, nb3;
+ } pushConsts {
+ safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
+ ne00, ne01, ne02,
+ nb00, nb01, nb02, nb03,
+ ne0, ne1, ne2,
+ nb0, nb1, nb2, nb3
+ };
+
+ std::string name = std::string(__func__)
+ + "_i_" + std::to_string(in_element_size)
+ + "_o_" + std::to_string(out_element_size);
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(name))
+ s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
+ else {
+ s_algo = komputeManager()->getAlgorithm(name);
+ s_algo->setTensors({in, out});
+ s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f32_f16(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
+ kp::shader_data::op_cpy_f32_f16_comp_spv_len);
+ ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f32_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
+ kp::shader_data::op_cpy_f32_f32_comp_spv_len);
+ ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f16_f16(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
+ kp::shader_data::op_cpy_f16_f16_comp_spv_len);
+ ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
+}
+
+template <typename... Args>
+static void ggml_vk_cpy_f16_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
+ kp::shader_data::op_cpy_f16_f32_comp_spv_len);
+ ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
+}
+
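+// Conservative capability check: only the op/type combinations with a matching
+// shader above are accepted, so unsupported cases can be handled by another
+// backend.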
+static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ break;
+ default:
+ return false;
+ }
+
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ ;
+ }
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_SCALE:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_NORM:
+ case GGML_OP_ROPE:
+ return true;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ default:
+ return false;
+ }
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ default:
+ return false;
+ }
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ return op->ne[3] == 1;
+ case GGML_OP_GET_ROWS:
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q6_K:
+ return op->ne[2] == 1 && op->ne[3] == 1;
+ default:
+ ;
+ }
+ return false;
+ case GGML_OP_MUL_MAT:
+ if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
+ return false;
+
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q6_K:
+ return op->ne[3] == 1;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ ;
+ }
+ default:
+ ;
+ }
+ return false;
+}
+
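+// Graph execution: the nodes are split across n_seq command sequences which are
+// recorded and submitted asynchronously, then awaited at the end. View-like ops
+// (reshape/view/transpose/permute) are skipped as no-ops.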
+static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
+ const int n_seq = 8;
+
+ // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
+ // it to the size of the graph, but I think it can be made smaller?
+ ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
+
+ std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
+
+ for (auto& sequence : sequences) {
+ sequence = komputeManager()->sequence();
+ }
+ for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
+ const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
+
+ auto& seq = *sequences[seq_idx];
+
+ const int node_start = (seq_idx + 0) * n_nodes_per_seq;
+ const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
+
+ bool any_commands_recorded = false;
+
+ for (int i = node_start; i < node_end; ++i) {
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
+ struct ggml_tensor * dst = gf->nodes[i];
+ GGML_ASSERT(dst->data != nullptr);
+
+ if (ggml_is_empty(dst)) {
+ continue;
+ }
+
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ continue; // noop -> next node
+ default:
+ break;
+ }
+
+ any_commands_recorded = true;
+
+ if (!ggml_vk_supports_op(dst)) {
+ fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
+
+ const int32_t ne00 = src0 ? src0->ne[0] : 0;
+ const int32_t ne01 = src0 ? src0->ne[1] : 0;
+ const int32_t ne02 = src0 ? src0->ne[2] : 0;
+ const int32_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint32_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint32_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint32_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint32_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int32_t ne10 = src1 ? src1->ne[0] : 0;
+ const int32_t ne11 = src1 ? src1->ne[1] : 0;
+ const int32_t ne12 = src1 ? src1->ne[2] : 0;
+ const int32_t ne13 = src1 ? src1->ne[3] : 0;
+
+ const uint32_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint32_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint32_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint32_t nb13 = src1 ? src1->nb[3] : 0;
+
+ const int32_t ne0 = dst ? dst->ne[0] : 0;
+ const int32_t ne1 = dst ? dst->ne[1] : 0;
+ const int32_t ne2 = dst ? dst->ne[2] : 0;
+// const int32_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint32_t nb0 = dst ? dst->nb[0] : 0;
+ const uint32_t nb1 = dst ? dst->nb[1] : 0;
+ const uint32_t nb2 = dst ? dst->nb[2] : 0;
+ const uint32_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
+ uint32_t off_src0 = 0;
+ uint32_t off_src1 = 0;
+ uint32_t off_dst = 0;
+ const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
+ const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+ const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
+
+ switch (dst->op) {
+ case GGML_OP_ADD:
+ {
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ // src1 is a row
+ ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
+ } else {
+ ggml_vk_add(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne03,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ );
+ }
+ } break;
+ case GGML_OP_MUL:
+ {
+ ggml_vk_mul(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne03,
+ nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb10, nb11, nb12, nb13,
+ ne0,
+ nb0, nb1, nb2, nb3
+ );
+ } break;
+ case GGML_OP_SCALE:
+ {
+ float scale; memcpy(&scale, dst->op_params, sizeof(float));
+
+ ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
+ } break;
+ case GGML_OP_UNARY:
+ {
+ int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+ switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_SILU:
+ {
+ ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ GGML_ASSERT(n % 8 == 0);
+ ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ }
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+ GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+ GGML_ASSERT(max_bias == 0.0f);
+
+ ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ const int n_past = ((int32_t *)(dst->op_params))[0];
+ ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
+ } break;
+ case GGML_OP_NORM:
+ {
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint32_t r2 = ne12/ne02;
+ const uint32_t r3 = ne13/ne03;
+
+ if (src1t != GGML_TYPE_F32) {
+ fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+ goto not_implemented;
+ }
+
+ if (ggml_is_transposed(src0) ||
+ ggml_is_transposed(src1)) {
+ fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+ goto not_implemented;
+ }
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ ggml_vk_mul_mat_mat_f32(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
+ );
+ break;
+ case GGML_TYPE_F16:
+ ggml_vk_mul_mat_f16(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
+ ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_vk_mul_mat_q8_0(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_0:
+ ggml_vk_mul_mat_q4_0(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q4_1:
+ ggml_vk_mul_mat_q4_1(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
+ );
+ break;
+ case GGML_TYPE_Q6_K:
+ ggml_vk_mul_mat_q6_k(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
+ );
+ break;
+ default: {
+ fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
+ goto not_implemented;
+ }
+ }
+
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ if (src0t == GGML_TYPE_F32) {
+ ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_F16) {
+ ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q4_0) {
+ ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q4_1) {
+ ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_Q6_K) {
+ ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else {
+ fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
+ goto not_implemented;
+ }
+ } break;
+ case GGML_OP_ROPE:
+ {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
+
+ GGML_ASSERT(ne10 == ne02);
+ GGML_ASSERT(src0t == dstt);
+ // const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ ggml_vk_rope(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+ ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
+ );
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+ default: goto not_implemented;
+ }
+ } break;
+                        case GGML_TYPE_F16:
+                            {
+                                switch (dstt) {
+                                    case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                    case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
+                                    default: goto not_implemented;
+                                }
+                            } break;
+                        default: goto not_implemented;
+                    }
+ } break;
+ default: goto not_implemented;
+ }
+ continue;
+ not_implemented: {}
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ //GGML_ASSERT(false);
+ }
+
+ // Evaluate sequence
+ if (any_commands_recorded) {
+ seq.evalAsync();
+ }
+ }
+
+ // Wait for all sequences to finish
+ for (auto& sequence : sequences) {
+ if (sequence->isRunning())
+ sequence->evalAwait();
+ }
+
+ ggml_vk_free_descriptor_pool(ctx);
+}
+
+template<>
+kp::Tensor::TensorDataTypes
+kp::TensorT<half>::dataType()
+{
+ return TensorDataTypes::eFloat;
+}
+
+template<>
+kp::Tensor::TensorDataTypes
+kp::TensorT<uint8_t>::dataType()
+{
+ return TensorDataTypes::eUnsignedInt;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+struct ggml_backend_kompute_buffer_type_context {
+ int device;
+ int device_ref = 0;
+ uint64_t buffer_alignment;
+ uint64_t max_alloc;
+ std::string name;
+
+ ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
+ : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
+};
+
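+// Reference-counted device lifetime: the Vulkan device (with the 8/16-bit storage
+// and fp16 extensions required by the shaders) is created on the first reference
+// and the manager is destroyed when the last reference is dropped.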
+static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+ if (!ctx->device_ref) {
+ komputeManager()->initializeDevice(
+ ctx->device, {}, {
+ "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
+ "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
+ }
+ );
+ }
+
+ assert(ggml_vk_has_device());
+ ctx->device_ref++;
+}
+
+static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+
+ assert(ctx->device_ref > 0);
+
+ ctx->device_ref--;
+
+ if (!ctx->device_ref) {
+ komputeManager.destroy();
+ }
+}
+
+static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
+ return ctx->name.c_str();
+}
+
+static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ auto * memory = (ggml_vk_memory *)buffer->context;
+ if (ggml_vk_has_device()) {
+ ggml_vk_free_memory(*memory);
+ }
+ delete memory;
+}
+
+static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return ((ggml_vk_memory *)buffer->context)->data;
+}
+
+static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ GGML_UNUSED(buffer);
+
+ const auto res = ggml_vk_get_tensor(tensor);
+ GGML_ASSERT(res);
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
+}
+
+static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ GGML_UNUSED(buffer);
+
+ const auto res = ggml_vk_get_tensor(tensor);
+ GGML_ASSERT(res);
+
+ komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+}
+
+static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ auto * memory = (ggml_vk_memory *)buffer->context;
+ memset(memory->data, value, buffer->size);
+
+ if (memory->stagingBuffer)
+ komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
+ /* .get_name = */ ggml_backend_kompute_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_kompute_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_kompute_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// default buffer type
+
+static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_kompute_device_ref(buft);
+ auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
+ return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
+}
+
+static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->buffer_alignment;
+}
+
+static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
+ return ctx->max_alloc;
+}
+
+static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_kompute_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ NULL,
+};
+
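+// One buffer type per available device, created lazily on first use and kept
+// alive for the lifetime of the process.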
+ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
+ static std::vector<ggml_backend_buffer_type> bufts = []() {
+ std::vector<ggml_backend_buffer_type> vec;
+ auto devices = ggml_vk_available_devices_internal(0);
+ vec.reserve(devices.size());
+
+ for (const auto & dev : devices) {
+ vec.push_back({
+ /* .iface = */ ggml_backend_kompute_buffer_type_interface,
+ /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
+ });
+ }
+ return vec;
+ }();
+
+ auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
+ return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
+ });
+ return it < bufts.end() ? &*it : nullptr;
+}
+
+// backend
+
+static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+ return ctx->name.c_str();
+}
+
+static void ggml_backend_kompute_free(ggml_backend_t backend) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+
+ assert(ctx == s_kompute_context);
+ s_kompute_context = nullptr;
+ if (ctx != nullptr) {
+ delete ctx;
+ }
+
+ delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+ return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
+ ggml_vk_graph_compute(ctx, cgraph);
+ return GGML_STATUS_SUCCESS;
+}
+
+static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ GGML_UNUSED(backend);
+ return ggml_vk_supports_op(op);
+}
+
+static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ GGML_UNUSED(backend);
+ return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
+}
+
+static struct ggml_backend_i kompute_backend_i = {
+ /* .get_name = */ ggml_backend_kompute_name,
+ /* .free = */ ggml_backend_kompute_free,
+ /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_kompute_graph_compute,
+ /* .supports_op = */ ggml_backend_kompute_supports_op,
+ /* .supports_buft = */ ggml_backend_kompute_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_kompute_guid() {
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
+ return &guid;
+}
+
+ggml_backend_t ggml_backend_kompute_init(int device) {
+ GGML_ASSERT(s_kompute_context == nullptr);
+ s_kompute_context = new ggml_kompute_context(device);
+
+ ggml_backend_t kompute_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_kompute_guid(),
+ /* .interface = */ kompute_backend_i,
+ /* .context = */ s_kompute_context,
+ };
+
+ return kompute_backend;
+}
+
+bool ggml_backend_is_kompute(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
+}
+
+static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
+ GGML_UNUSED(params);
+ return ggml_backend_kompute_init(intptr_t(user_data));
+}
+
+extern "C" int ggml_backend_kompute_reg_devices();
+
+int ggml_backend_kompute_reg_devices() {
+ auto devices = ggml_vk_available_devices_internal(0);
+ for (const auto & device : devices) {
+ ggml_backend_register(
+ ggml_kompute_format_name(device.index).c_str(),
+ ggml_backend_reg_kompute_init,
+ ggml_backend_kompute_buffer_type(device.index),
+ reinterpret_cast<void *>(intptr_t(device.index))
+ );
+ }
+ return devices.size();
+}
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
new file mode 100644
index 00000000..388c2008
--- /dev/null
+++ b/ggml/src/ggml-metal.m
@@ -0,0 +1,3380 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#import "ggml-metal.h"
+
+#import "ggml-backend-impl.h"
+#import "ggml.h"
+
+#import <Foundation/Foundation.h>
+
+#import <Metal/Metal.h>
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#ifdef GGML_METAL_NDEBUG
+#define GGML_METAL_LOG_INFO(...)
+#define GGML_METAL_LOG_WARN(...)
+#define GGML_METAL_LOG_ERROR(...)
+#else
+#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#endif
+
+#define UNUSED(x) (void)(x)
+
+struct ggml_metal_kernel {
+ id<MTLComputePipelineState> pipeline;
+};
+
+enum ggml_metal_kernel_type {
+ GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_4,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_4,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_4,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
+ GGML_METAL_KERNEL_TYPE_SCALE,
+ GGML_METAL_KERNEL_TYPE_SCALE_4,
+ GGML_METAL_KERNEL_TYPE_CLAMP,
+ GGML_METAL_KERNEL_TYPE_TANH,
+ GGML_METAL_KERNEL_TYPE_RELU,
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
+ GGML_METAL_KERNEL_TYPE_GELU,
+ GGML_METAL_KERNEL_TYPE_GELU_4,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
+ GGML_METAL_KERNEL_TYPE_SILU,
+ GGML_METAL_KERNEL_TYPE_SILU_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+ GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+ GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+ GGML_METAL_KERNEL_TYPE_NORM,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+ //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+ GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+ GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+ GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+ GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+ GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+ GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+ GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+ GGML_METAL_KERNEL_TYPE_COUNT
+};
+
+struct ggml_metal_context {
+ int n_cb;
+
+ id<MTLDevice> device;
+ id<MTLCommandQueue> queue;
+
+ dispatch_queue_t d_queue;
+
+ struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
+
+ bool support_simdgroup_reduction;
+ bool support_simdgroup_mm;
+
+ bool should_capture_next_compute;
+};
+
+// MSL code
+// TODO: move the contents here when ready
+// for now it is easier to work in a separate file
+// static NSString * const msl_library_source = @"see metal.metal";
+
+// Here to assist with NSBundle Path Hack
+@interface GGMLMetalClass : NSObject
+@end
+@implementation GGMLMetalClass
+@end
+
+static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+}
+
+ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
+void * ggml_metal_log_user_data = NULL;
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
+ if (ggml_metal_log_callback != NULL) {
+ va_list args;
+ va_start(args, format);
+ char buffer[128];
+ int len = vsnprintf(buffer, 128, format, args);
+ if (len < 128) {
+ ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
+ } else {
+ char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
+ vsnprintf(buffer2, len+1, format, args);
+ buffer2[len] = 0;
+ ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
+ free(buffer2);
+ }
+ va_end(args);
+ }
+}
+
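+// Page-aligned host allocation (vm_allocate on macOS, posix_memalign elsewhere),
+// suitable for wrapping in a no-copy Metal buffer later on.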
+static void * ggml_metal_host_malloc(size_t n) {
+ void * data = NULL;
+
+#if TARGET_OS_OSX
+ kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
+ if (err != KERN_SUCCESS) {
+ GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
+ return NULL;
+ }
+#else
+ const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+ if (result != 0) {
+ GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+ return NULL;
+ }
+#endif
+
+ return data;
+}
+
+static struct ggml_metal_context * ggml_metal_init(int n_cb) {
+ GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
+
+#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
+ // Show all the Metal device instances in the system
+ NSArray * devices = MTLCopyAllDevices();
+ for (id<MTLDevice> device in devices) {
+ GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
+ }
+ [devices release]; // since it was created by a *Copy* C method
+#endif
+
+ // Pick and show default Metal device
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+ GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
+
+ // Configure context
+ struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ctx->device = device;
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
+ ctx->queue = [ctx->device newCommandQueue];
+ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
+
+ id<MTLLibrary> metal_library;
+
+ // load library
+ //
+ // - first check if the library is embedded
+ // - then check if the library is in the bundle
+ // - if not found, load the source and compile it
+ // - if that fails, return NULL
+ {
+ NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+ bundle = SWIFTPM_MODULE_BUNDLE;
+#else
+ bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+#endif
+
+ NSError * error = nil;
+
+#if GGML_METAL_EMBED_LIBRARY
+ const bool try_metallib = false;
+#else
+ const bool try_metallib = true;
+#endif
+
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
+ if (try_metallib && path_lib != nil) {
+ // pre-compiled library found
+ NSURL * libURL = [NSURL fileURLWithPath:path_lib];
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
+
+ metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+ if (error) {
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return NULL;
+ }
+ } else {
+#if GGML_METAL_EMBED_LIBRARY
+ GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
+
+ extern const char ggml_metallib_start[];
+ extern const char ggml_metallib_end[];
+
+ NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
+#else
+ GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+ NSString * path_source;
+ NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+ GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
+
+ if (path_resource) {
+ path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
+ } else {
+ path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+ }
+
+ if (path_source == nil) {
+ GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+ path_source = @"ggml-metal.metal";
+ }
+
+ GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
+
+ NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
+ if (error) {
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return NULL;
+ }
+#endif // GGML_METAL_EMBED_LIBRARY
+
+ @autoreleasepool {
+ // dictionary of preprocessor macros
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+ MTLCompileOptions* options = [MTLCompileOptions new];
+ options.preprocessorMacros = prep;
+
+ //[options setFastMathEnabled:false];
+
+ metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+ if (error) {
+ GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return NULL;
+ }
+ }
+ }
+ }
+
+ // print MTL GPU family:
+ GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ const NSInteger MTLGPUFamilyMetal3 = 5001;
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ {
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+ break;
+ }
+ }
+
+ for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+ break;
+ }
+ }
+ }
+
+ ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+ ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+ ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+ GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
+ GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+
+ ctx->should_capture_next_compute = false;
+
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ }
+#elif TARGET_OS_OSX
+ if (ctx->device.maxTransferRate != 0) {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+#endif
+
+ // load kernels
+ {
+ NSError * error = nil;
+
+ for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+ ctx->kernels[i].pipeline = nil;
+ }
+
+ /*
+ GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+ (int) kernel->pipeline.threadExecutionWidth); \
+ */
+#define GGML_METAL_ADD_KERNEL(e, name, supported) \
+ if (supported) { \
+ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+ id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
+ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+ [metal_function release]; \
+ if (error) { \
+ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+ [metal_library release]; \
+ return NULL; \
+ } \
+ } else { \
+ GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
+ }
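+        // note: when `supported` is false the kernel's pipeline stays nil (all pipelines were initialized
+        //       to nil above) and only a warning is printed; ops that need such a kernel are later rejected
+        //       by ggml_metal_supports_op, which checks the same support flags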
+
+        // simd_sum and simd_max require MTLGPUFamilyApple7
+
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN, get_rows_iq1_bn, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32, mul_mv_iq1_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32, mul_mv_id_iq1_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32, mul_mm_iq1_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32, mul_mm_id_iq1_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ }
+
+ [metal_library release];
+ return ctx;
+}
+
+static void ggml_metal_free(struct ggml_metal_context * ctx) {
+ GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
+
+ for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+ [ctx->kernels[i].pipeline release];
+ }
+
+ [ctx->queue release];
+ [ctx->device release];
+
+ dispatch_release(ctx->d_queue);
+
+ free(ctx);
+}
+
+// temporarily defined here for compatibility between ggml-backend and the old API
+
+struct ggml_backend_metal_buffer {
+ void * data;
+ size_t size;
+
+ id<MTLBuffer> metal;
+};
+
+struct ggml_backend_metal_buffer_context {
+ void * all_data;
+ size_t all_size;
+ bool owned;
+
+ // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
+ int n_buffers;
+ struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+};
+
+// finds the Metal buffer that contains the tensor data on the GPU device
+// the assumption is that there is a 1-to-1 mapping between the host and device memory buffers, so we can find the
+// Metal buffer based on the host memory pointer
+//
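+// e.g. if t->data points ioffs bytes into buffers[i].data and ioffs + ggml_nbytes(t) still fits within
+// buffers[i].size, then buffers[i].metal is returned together with that byte offset
+//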
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
+ //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+ const int64_t tsize = ggml_nbytes(t);
+
+ ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+ struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
+
+ // find the view that contains the tensor fully
+ for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
+
+ //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
+ *offs = (size_t) ioffs;
+
+ //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
+
+ return buf_ctx->buffers[i].metal;
+ }
+ }
+
+ GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
+
+ return nil;
+}
+
+static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
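+    // none of the Metal kernels registered above operate on BF16 data, so reject any op with a BF16 source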
+ for (size_t i = 0, n = 3; i < n; ++i) {
+ if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
+ return false;
+ }
+ }
+
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ return false;
+ }
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_CONCAT:
+ case GGML_OP_ADD:
+ case GGML_OP_ACC:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_REPEAT:
+ case GGML_OP_SCALE:
+ case GGML_OP_CLAMP:
+ case GGML_OP_SQR:
+ case GGML_OP_SUM_ROWS:
+ return true;
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
+ return ctx->support_simdgroup_reduction;
+ case GGML_OP_NORM:
+ case GGML_OP_ROPE:
+ case GGML_OP_IM2COL:
+ return true;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ return false;
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_LEAKY_RELU:
+ return true;
+ case GGML_OP_FLASH_ATTN_EXT:
+ if (op->src[1]->type != GGML_TYPE_F16) {
+ return false;
+ }
+ if (op->src[2]->type != GGML_TYPE_F16) {
+ return false;
+ }
+ if (op->src[0]->ne[0] == 256) {
+ return false;
+ }
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ return ctx->support_simdgroup_reduction &&
+ (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_IQ4_NL:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+                }
+ }
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_GET_ROWS:
+ {
+ return op->ne[3] == 1;
+ }
+ default:
+ return false;
+ }
+}
+
+static enum ggml_status ggml_metal_graph_compute(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+
+ @autoreleasepool {
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+ edesc.dispatchType = MTLDispatchTypeSerial;
+
+ // create multiple command buffers and enqueue them
+ // then, we encode the graph into the command buffers in parallel
+
+ const int n_nodes = gf->n_nodes;
+ const int n_cb = ctx->n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
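+    // e.g. (illustrative numbers) n_nodes = 100 and n_cb = 8 give n_nodes_per_cb = 13:
+    // the first 7 command buffers encode 13 nodes each and the last one encodes the remaining 9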
+
+ const bool should_capture = ctx->should_capture_next_compute;
+ if (should_capture) {
+ ctx->should_capture_next_compute = false;
+
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+ descriptor.captureObject = ctx->queue;
+
+ NSError * error = nil;
+ if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+ GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+ GGML_ASSERT(!"capture failed");
+ }
+ }
+
+ id<MTLCommandBuffer> command_buffer_builder[n_cb];
+ for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+ id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+ command_buffer_builder[cb_idx] = command_buffer;
+
+ // enqueue the command buffers in order to specify their execution order
+ [command_buffer enqueue];
+ }
+
+ const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
+
+ dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
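+        // each iteration of this block runs on ctx->d_queue and encodes its own slice of the graph
+        // into its own command buffer, so the command buffers are filled in parallel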
+ const int cb_idx = iter;
+
+ size_t offs_src0 = 0;
+ size_t offs_src1 = 0;
+ size_t offs_src2 = 0;
+ size_t offs_dst = 0;
+
+ id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
+
+ for (int i = node_start; i < node_end; ++i) {
+ if (i == -1) {
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }
+
+ //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+ struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+ struct ggml_tensor * dst = gf->nodes[i];
+
+ if (ggml_is_empty(dst)) {
+ continue;
+ }
+
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop -> next node
+ } continue;
+ default:
+ {
+ } break;
+ }
+
+ if (!ggml_metal_supports_op(ctx, dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
+
+ if (should_capture) {
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+ }
+
+ const int64_t ne00 = src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = src0 ? src0->ne[3] : 0;
+
+ const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+ const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+ const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+ const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+ const int64_t ne10 = src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
+
+ const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+ const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+ const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
+
+ const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+ const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
+
+ id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+ id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+ id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
+
+ //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+ //if (src0) {
+ // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+ // ggml_is_contiguous(src0), src0->name);
+ //}
+ //if (src1) {
+ // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+ // ggml_is_contiguous(src1), src1->name);
+ //}
+ //if (dst) {
+ // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+ // dst->name);
+ //}
+
+ switch (dst->op) {
+ case GGML_OP_CONCAT:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ const size_t offs = 0;
+
+ bool bcast_row = false;
+
+ int64_t nb = ne00; // used by the "row" kernels
+
+ id<MTLComputePipelineState> pipeline = nil;
+
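+                            // special case: when src1 holds a single scalar and src0 is contiguous,
+                            // the multiplication is dispatched as a scale kernel instead of a broadcast mul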
+ if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) {
+ float scale;
+ memcpy(&scale, src1->data, sizeof(float));
+ //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale);
+
+ int64_t n = ggml_nelements(dst);
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
+ else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
+ dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
+ dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
+ dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
+ dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
+
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+
+ int64_t n = ggml_nelements(dst)/4;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
+ else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // src1 is a row
+ GGML_ASSERT(ne11 == 1);
+
+ nb = ne00 / 4;
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+
+ bcast_row = true;
+ } else {
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
+
+ if (bcast_row) {
+ const int64_t n = ggml_nelements(dst)/4;
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } else {
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ id<MTLComputePipelineState> pipeline;
+
+ switch (src0t) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ACC:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+                                // run a separate kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_SCALE:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(scale));
+
+ int64_t n = ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ n /= 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+ float min;
+ float max;
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_UNARY:
+                        // we are not taking into account the strides, so for now require contiguous tensors
+                        // note: this check must come before the switch - statements placed before the first
+                        //       case label of a switch are never executed
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+ case GGML_UNARY_OP_TANH:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ int64_t n = ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ int64_t n = ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ int64_t n = ggml_nelements(dst);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (n % 4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_OP_SQR:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+ int nth = 32; // SIMD width
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
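+                                    // e.g. ne00 = 4096 with a single row: nth doubles 32 -> 64 -> 128 -> 256
+                                    // and then stops, since 256*1 < 256 no longer holds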
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+ }
+ } else {
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+ nth *= 2;
+ }
+ if (use_f16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+ }
+ }
+
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ const uint32_t n_head = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
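+                                // illustrative values: with max_bias = 8.0f and n_head = 32 we get n_head_log2 = 32,
+                                // so m0 = 2^(-8/32) ~ 0.841 and m1 = 2^(-4/32) ~ 0.917
+                                // (with max_bias = 0.0f both evaluate to 1.0f)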
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ const int n_past = ((int32_t *)(dst->op_params))[0];
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ne00%8 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ GGML_ASSERT(ne00 == ne10);
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient
+                            // than the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+#if 0
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
+ // these numbers do not translate to other devices or model sizes
+ // TODO: need to find a better approach
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
+ }
+ }
+#endif
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                            // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ !ggml_is_transposed(src0) &&
+ !ggml_is_transposed(src1) &&
+ src1t == GGML_TYPE_F32 &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+ nrows = 4;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ nth0 = 32;
+ nth1 = 1;
+ if (src1t == GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+ nrows = ne11;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
+ nrows = 4;
+ }
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
+ nrows = 4;
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_M:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S ||
+ src0t == GGML_TYPE_IQ1_BN || src0t == GGML_TYPE_IQ2_BN) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q3_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ const int n_as = src0->ne[2];
+
+ // src2 = ids
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ // ne20 = n_used_experts
+ // ne21 = n_rows
+ const int dst_rows = ne20*ne21;
+ const int dst_rows_min = n_as;
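+ // threadgroup memory budget: 8192 bytes for the mat-mat kernel's shared buffer plus 4 bytes per row id (see the setThreadgroupMemoryLength call below), with 32 bytes of headroom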
+ const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+
+ // max size of the rowids array in the kernel shared buffer
+ GGML_ASSERT(dst_rows <= dst_rows_max);
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPUs and older A-chips will reuse the matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
+ dst_rows > dst_rows_min) {
+
+ // some Metal matrix data types require aligned pointers
+ // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+ switch (src0->type) {
+ case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
+ case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
+ default: break;
+ }
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+ default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ // use custom matrix x vector kernel
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ3_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_M:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ1_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ2_BN:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ } break;
+ case GGML_TYPE_IQ4_XS:
+ {
+ nth0 = 4;
+ nth1 = 16;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ if (ggml_is_quantized(src0t)) {
+ GGML_ASSERT(ne00 >= nth0*nth1);
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
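+ // in the indirect (ID) mat-vec path the rows are processed one at a time (_ne1 = 1) and each of the dst_rows rows gets its own threadgroup along z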
+ const int64_t _ne1 = 1;
+ const int tgz = dst_rows;
+
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S ||
+ src0t == GGML_TYPE_IQ1_BN || src0t == GGML_TYPE_IQ2_BN) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+ const int mem_size = 32*sizeof(float);
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q3_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
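+ // one threadgroup per selected row id (ne10 x ne11 ids); the 32 threads of a threadgroup handle one row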
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ int nth = 32; // SIMD width
+
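+ // grow the threadgroup to a power of two, capped at 1024 threads and at ne00/4 (the kernel appears to process 4 values per thread)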
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_NORM:
+ {
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const int nth = MIN(256, ne00);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ROPE:
+ {
+ GGML_ASSERT(ne10 == ne02);
+
+ const int nth = MIN(1024, ne00);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ // skip op_params[3] (n_ctx), used in GLM RoPE, not implemented in Metal
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ const bool is_neox = mode & 2;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (!is_neox) {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+ } else {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src2 != nil) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
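+ // tensor extents: N = batch, IC = input channels, IH/IW = input size, KH/KW = kernel size, OH/OW = output size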
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const float sf0 = (float)ne0/src0->ne[0];
+ const float sf1 = (float)ne1/src0->ne[1];
+ const float sf2 = (float)ne2/src0->ne[2];
+ const float sf3 = (float)ne3/src0->ne[3];
+
+ const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float step;
+
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
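+ // the embedding is built from dim/2 frequency pairs (presumably sin/cos), so the threads only need to cover half the dimensions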
+ const int half = dim / 2;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+ const int nth = MIN(1024, half);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ // bitonic sort requires the number of elements to be a power of 2
+ int64_t ne00_padded = 1;
+ while (ne00_padded < ne00) {
+ ne00_padded *= 2;
+ }
+
+ // Metal kernels require the threadgroup memory size to be a multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (order) {
+ case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+ case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+ default: GGML_ASSERT(false);
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ne11 % 32 == 0);
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
+
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+ size_t offs_src3 = 0;
+
+ id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+ GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+ "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+ const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
+ const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+ const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+ const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+ const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+ const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+ const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+ const uint32_t n_head = src0->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
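+ // m0/m1 are the per-head bias slope bases derived from max_bias (the same m0/m1/n_head_log2 trio that is passed to the soft-max kernel above)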
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ bool use_vec_kernel = false;
+
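+ // the block (half8x8) kernel handles batched queries; the vector kernel is used only for small batches (ne01 < 4) with a head size that is a multiple of 128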
+ if (ne01 >= 4 || (ne00%128 != 0)) {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ } else {
+ use_vec_kernel = true;
+
+ switch (ne00) {
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ASSERT(false && "add template specialization for this size");
+ }
+ }
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+
+ if (!use_vec_kernel) {
+ // half8x8 kernel
+ const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 8 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ int64_t nsgmax = 2;
+
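+ // probe for the largest power-of-two simdgroup count whose shared-memory requirement still fits into the threadgroup memory limit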
+ while (true) {
+ const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
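+ // sizeof(float)/2 == sizeof(half): the kernel keeps its scratch data in half precision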
+ const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ // half1x4 kernel
+ const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+ GGML_ASSERT(nqptg <= 32);
+ GGML_ASSERT(nqptg % 1 == 0);
+ GGML_ASSERT(ncpsg % 32 == 0);
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+ int64_t nsg = 1;
+ while (nsg <= nsgt) {
+ nsg *= 2;
+ }
+ nsg /= 2;
+
+ const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+ //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+ GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+ [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
+ default: GGML_ASSERT(false && "not implemented");
+ }
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+ GGML_ASSERT(false);
+ }
+ }
+
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
+ }
+
+ [encoder endEncoding];
+
+ [command_buffer commit];
+ });
+
+ // Wait for completion and check status of each command buffer
+ // needed to detect if the device ran out of memory, for example (#1881)
+
+ for (int i = 0; i < n_cb; ++i) {
+ id<MTLCommandBuffer> command_buffer = command_buffers[i];
+ [command_buffer waitUntilCompleted];
+
+ MTLCommandBufferStatus status = [command_buffer status];
+ if (status != MTLCommandBufferStatusCompleted) {
+ GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ if (status == MTLCommandBufferStatusError) {
+ NSString * error_code = [command_buffer error].localizedDescription;
+ GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
+ }
+
+ return GGML_STATUS_FAILED;
+ }
+ }
+
+ if (should_capture) {
+ [[MTLCaptureManager sharedCaptureManager] stopCapture];
+ }
+
+ }
+ return GGML_STATUS_SUCCESS;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+// default buffer
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;
+
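+ // the default Metal device is shared and reference-counted: every ggml_backend_metal_get_device() call must be balanced by a ggml_backend_metal_free_device() call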
+static id<MTLDevice> ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
+}
+
+static void ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ [g_backend_device release];
+ g_backend_device = nil;
+ }
+}
+
+GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return "Metal";
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ for (int i = 0; i < ctx->n_buffers; i++) {
+ [ctx->buffers[i].metal release];
+ }
+ ggml_backend_metal_free_device();
+
+ if (ctx->owned) {
+#if TARGET_OS_OSX
+ vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+#else
+ free(ctx->all_data);
+#endif
+ }
+
+ free(ctx);
+}
+
+GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->all_data;
+}
+
+GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_host(src->buffer)) {
+ memcpy(dst->data, src->data, ggml_nbytes(src));
+ return true;
+ }
+ return false;
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+ memset(ctx->all_data, value, ctx->all_size);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
+ /* .get_name = */ ggml_backend_metal_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_metal_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// default buffer type
+
+GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "Metal";
+
+ UNUSED(buft);
+}
+
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+ if (@available(macOS 10.12, iOS 16.0, *)) {
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
+ device.currentAllocatedSize / 1024.0 / 1024.0,
+ device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+ if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+ GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ } else {
+ GGML_METAL_LOG_INFO("\n");
+ }
+ } else {
+ GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+ __func__,
+ size_aligned / 1024.0 / 1024.0,
+ device.currentAllocatedSize / 1024.0 / 1024.0);
+ }
+#endif
+#endif
+ UNUSED(device);
+ UNUSED(size_aligned);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
+
+ id<MTLDevice> device = ggml_backend_metal_get_device();
+
+ ctx->all_data = ggml_metal_host_malloc(size_aligned);
+ ctx->all_size = size_aligned;
+ ctx->owned = true;
+ ctx->n_buffers = 1;
+
+ if (ctx->all_data != NULL) {
+ ctx->buffers[0].data = ctx->all_data;
+ ctx->buffers[0].size = size;
+ ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+ }
+
+ if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+ free(ctx);
+ ggml_backend_metal_free_device();
+ return NULL;
+ }
+
+ //ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 32;
+ UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ id<MTLDevice> device = ggml_backend_metal_get_device();
+ size_t max_size = device.maxBufferLength;
+ ggml_backend_metal_free_device();
+
+ return max_size;
+
+ UNUSED(buft);
+}
+
+GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return true;
+
+ UNUSED(buft);
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_metal_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size,
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+ /* .is_host = */ ggml_backend_metal_buffer_type_is_host,
+ },
+ /* .context = */ NULL,
+ };
+
+ return &ggml_backend_buffer_type_metal;
+}
+
+// buffer from ptr
+
+GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+ struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
+
+ ctx->all_data = data;
+ ctx->all_size = size;
+ ctx->owned = false;
+ ctx->n_buffers = 0;
+
+ const size_t size_page = sysconf(_SC_PAGESIZE);
+
+ // page-align the data ptr
+ {
+ const uintptr_t offs = (uintptr_t) data % size_page;
+ data = (void *) ((char *) data - offs);
+ size += offs;
+ }
+
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }
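+    // worked example of the two adjustments above, assuming a 4 KiB page size:
+    //   data = 0x100A, size = 0x2000  ->  offs = 0x00A, data = 0x1000, size = 0x200A
+    //   size_aligned then rounds 0x200A up to the next page boundary, i.e. 0x3000
+    // newBufferWithBytesNoCopy expects page-aligned storage, hence both adjustments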
+
+ id<MTLDevice> device = ggml_backend_metal_get_device();
+
+ // the buffer fits into the max buffer size allowed by the device
+ if (size_aligned <= device.maxBufferLength) {
+ ctx->buffers[ctx->n_buffers].data = data;
+ ctx->buffers[ctx->n_buffers].size = size;
+
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+            return NULL;
+ }
+
+ ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+ ++ctx->n_buffers;
+ } else {
+ // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+ // one of the views
+ const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+ const size_t size_step = device.maxBufferLength - size_ovlp;
+ const size_t size_view = device.maxBufferLength;
+
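+        // layout sketch (not to scale): view k starts at k*size_step and spans size_view bytes,
+        // so consecutive views overlap by size_ovlp >= max_size bytes:
+        //
+        //   view 0: [0 .......................... size_view)
+        //   view 1:            [size_step .......................... size_step + size_view)
+        //
+        // any tensor of at most max_size bytes therefore lies entirely inside at least one view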
+ for (size_t i = 0; i < size; i += size_step) {
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+ ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
+ ctx->buffers[ctx->n_buffers].size = size_step_aligned;
+
+ ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+ if (ctx->buffers[ctx->n_buffers].metal == nil) {
+ GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                return NULL;
+ }
+
+ ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+ if (i + size_step < size) {
+ GGML_METAL_LOG_INFO("\n");
+ }
+
+ ++ctx->n_buffers;
+ }
+ }
+
+ return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+// backend
+
+GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+}
+
+GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) {
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ggml_metal_free(ctx);
+ free(backend);
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
+ return ggml_backend_metal_buffer_type();
+
+ UNUSED(backend);
+}
+
+GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+
+ return ggml_metal_graph_compute(metal_ctx, cgraph);
+}
+
+GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+ struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+
+ return ggml_metal_supports_op(metal_ctx, op);
+}
+
+GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+ UNUSED(backend);
+}
+
+static struct ggml_backend_i ggml_backend_metal_i = {
+ /* .get_name = */ ggml_backend_metal_name,
+ /* .free = */ ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ NULL,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_metal_graph_compute,
+ /* .supports_op = */ ggml_backend_metal_supports_op,
+ /* .supports_buft = */ ggml_backend_metal_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
+ ggml_metal_log_callback = log_callback;
+ ggml_metal_log_user_data = user_data;
+}
+
+static ggml_guid_t ggml_backend_metal_guid(void) {
+ static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
+ return &guid;
+}
+
+ggml_backend_t ggml_backend_metal_init(void) {
+ struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+ if (ctx == NULL) {
+ return NULL;
+ }
+
+ ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
+
+ *metal_backend = (struct ggml_backend) {
+ /* .guid = */ ggml_backend_metal_guid(),
+ /* .interface = */ ggml_backend_metal_i,
+ /* .context = */ ctx,
+ };
+
+ return metal_backend;
+}
+
+bool ggml_backend_is_metal(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
+}
+
+void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
+}
+
+bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+ ctx->should_capture_next_compute = true;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return ggml_backend_metal_init();
+
+ GGML_UNUSED(params);
+ GGML_UNUSED(user_data);
+}
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
new file mode 100644
index 00000000..67dcf53d
--- /dev/null
+++ b/ggml/src/ggml-metal.metal
@@ -0,0 +1,6563 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#include "ggml-common.h"
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+enum ggml_sort_order {
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
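+// note: the broadcast itself is implemented via modulo indexing (i13 = i03 % ne13, ..., i10 = i0 % ne10);
+// whenever a src1 dimension is 1, the same src1 element is reused across the whole corresponding dst dimension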
+kernel void kernel_add(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int64_t & offs,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+ }
+}
+
+template<typename T>
+kernel void kernel_repeat(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3 % ne03;
+ const int64_t i02 = i2 % ne02;
+ const int64_t i01 = i1 % ne01;
+
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i00 = i0 % ne00;
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+ }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
+// assumption: src1 is a row
+// broadcast src1 into src0
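+// `nb` here is the length of the broadcast row expressed in float4 elements (presumably ne10/4 on the host
+// side), so `tpig % nb` wraps the src1 index around the row; kernel_mul_row and kernel_div_row work the same way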
+kernel void kernel_add_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+kernel void kernel_add_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig];
+}
+
+kernel void kernel_mul_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_mul_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src1[tpig];
+}
+
+kernel void kernel_div_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+kernel void kernel_div_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] / src1[tpig];
+}
+
+kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+ device const float * src0,
+ device float * dst,
+ constant float & min,
+ constant float & max,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
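+// for reference, the kernels below compute:
+//   GELU(x)      ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))   (tanh approximation)
+//   GELU_quick(x) = x*sigmoid(1.702*x) = x/(1 + exp(-1.702*x))
+// with SQRT_2_OVER_PI = sqrt(2/pi) and the sign of GELU_QUICK_COEF folded into the exp() call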
+
+kernel void kernel_gelu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
+ }
+
+ dst_row[0] = row_sum;
+}
+
+template<typename T>
+kernel void kernel_soft_max(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
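+    // m0, m1 and n_head_log2 are precomputed on the host: the first n_head_log2 heads get slopes
+    // m0^1, m0^2, ..., while the remaining heads get m1^1, m1^3, m1^5, ..., matching the usual
+    // ALiBi slope schedule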
+
+ // parallel max
+ float lmax = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+ }
+
+ // find the max value in the block
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
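+
+    // two-level reduction: simd_max first reduces within each simdgroup, lane 0 of every simdgroup then
+    // publishes its partial maximum to threadgroup memory (padded with -INFINITY), and a second simd_max
+    // over those partials yields the block-wide maximum in every thread; the sums below reuse the same
+    // pattern with 0.0f padding and simd_sum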
+
+ // parallel sum
+ float lsum = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+ lsum += exp_psrc0;
+ pdst[i00] = exp_psrc0;
+ }
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
+ }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+ float slope = 1.0f;
+
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // parallel max
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ }
+
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ pdst4[i00] *= inv_sum;
+ }
+}
+
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
+kernel void kernel_diag_mask_inf(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+ const int64_t i02 = tpig[2];
+ const int64_t i01 = tpig[1];
+ const int64_t i00 = tpig[0];
+
+ if (i00 > n_past + i01) {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+ } else {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
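+    // each thread copies two consecutive float4 values (8 scalars) and then, walking backwards from the
+    // last element, overwrites with -INFINITY every element whose column index (i00 + k for the first
+    // float4, i00 + 4 + k for the second) lies beyond n_past + i01; the early break fires once the
+    // remaining elements are all inside the allowed region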
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
+ }
+}
+
+kernel void kernel_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ // MEAN
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += x[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ const float mean = sum[0] / ne00;
+
+ // recenter and VARIANCE
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ device float * y = dst + tgpig*ne00;
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
+ sum[tpitg] += y[i00] * y[i00];
+ }
+
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ const float variance = sum[0] / ne00;
+
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = y[i00] * scale;
+ }
+}
+
+kernel void kernel_rms_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+ float4 sumf = 0;
+ float all_sum = 0;
+
+ // parallel sum
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
+
+ const float mean = all_sum/ne00;
+ const float scale = 1.0f/sqrt(mean + eps);
+
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ y[i00] = x[i00] * scale;
+ }
+}
+
+kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+}
+
+// function for calculating the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factors
+// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]);
+}
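+
+// worked example of the scale-factor trick above: the caller (mul_vec_q_n_f32_impl, further below) loads
+// yl[i+1] = y/256, yl[i+8] = y/16 and yl[i+9] = y/4096; masking a 16-bit word with 0x000F/0x0F00/0x00F0/0xF000
+// leaves each 4-bit quant at its original bit position, i.e. scaled by 1, 256, 16 and 4096 respectively, and
+// the pre-divided y values cancel exactly those factors, so no bit shifts are needed in the hot loop;
+// the final `sumy * -8.f` accounts for the q4_0 zero-point (quants are stored in [0,15] and the value is d*(q-8))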
+
+// function for calculating the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factors
+// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// function for calculating the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factors
+// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculating the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied by the appropriate scale factors
+// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+// note: moving these constants inside the kernel causes a significant performance penalty
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+// Note: This is a template, but strictly speaking it only applies to
+//       quantizations where the block size is 32. It also does not
+//       guard against the number of rows not being divisible by
+//       N_DST, so this is another explicit assumption of the implementation.
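+// work decomposition used below: each threadgroup runs N_SIMDGROUP simdgroups, each simdgroup computes
+// N_DST consecutive output rows (first_row = (tgpig.x*nsg + sgitg)*nr), and the 32 threads of a simdgroup
+// split the ne00/QK4_0 blocks between them, two threads per block (one per half, see ix/il), each thread
+// accumulating partial sums for all N_DST rows which are then combined with simd_sum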
+template<typename block_q_type, int nr, int nsg, int nw>
+void mul_vec_q_n_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig, uint tiisg, uint sgitg) {
+ const int nb = ne00/QK4_0;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16]; // src1 vector cache
+ float sumf[nr] = {0.f};
+
+ const int ix = (tiisg/2);
+ const int il = (tiisg%2)*8;
+
+ device const float * yb = y + ix * QK4_0 + il;
+
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
+ float sumy = 0;
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
+ }
+
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+ }
+
+ yb += QK4_0 * 16;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+ }
+ }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+
+#define NB_Q8_0 8
+
+void kernel_mul_mv_q8_0_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = ne00/QK8_0;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[NB_Q8_0];
+ float sumf[nr]={0.f};
+
+ const int ix = tiisg/4;
+ const int il = tiisg%4;
+
+ device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (int ib = ix; ib < nb; ib += nw/4) {
+ for (int i = 0; i < NB_Q8_0; ++i) {
+ yl[i] = yb[i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+ float sumq = 0.f;
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
+ sumq += qs[iq] * yl[iq];
+ }
+ sumf[row] += sumq*x[ib+row*nb].d;
+ }
+
+ yb += NB_Q8_0 * nw;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+}
+
+#define N_MV_T_T 4
+
+template<typename T0, typename T04, typename T1, typename T14>
+void kernel_mul_mv_impl(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_MV_T_T;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const T0 * x = (device const T0 *) (src0 + offset0);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (T0) x[i] * (T1) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const T04 * x4 = (device const T04 *) x;
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+ device const T14 * y4 = (device const T14 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_impl<T0, T04, T1, T14>(
+ src0,
+ src1,
+ dst,
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
+
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const T * x = (device const T *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ if (ne00 < 128) {
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ } else {
+ device const T4 * x4 = (device const T4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ }
+
+ float all_sum = simd_sum(sumf);
+
+ if (tiisg == 0) {
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
+
+// Assumes row size (ne00) is a multiple of 4
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+ device const T4 * x4 = (device const T4 *) (src0 + offset0);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static inline void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+}
+
+// Solving `n_rot = max_pos_emb / (2pi * base^(2*x / n_dims))` for the dimension index x gives
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
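+
+// corr_dims[0]..corr_dims[1] is the range of rotary dimensions over which rope_yarn_ramp() blends the two
+// angles: below corr_dims[0] the ramp is 1 (high-frequency dims, weighted towards theta_extrap when
+// ext_factor = 1) and above corr_dims[1] it is 0 (low-frequency dims, pure theta_interp)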
+
+template<typename T>
+kernel void kernel_rope_norm(
+ device const void * src0,
+ device const int32_t * src1,
+ device const float * src2,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & n_ctx_orig,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int64_t i3 = tgpig[2];
+ const int64_t i2 = tgpig[1];
+ const int64_t i1 = tgpig[0];
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+ device const int32_t * pos = src1;
+
+ const float theta_base = (float) pos[i2];
+ const float inv_ndims = -1.f/n_dims;
+
+ float cos_theta;
+ float sin_theta;
+
+ for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+ if (i0 < n_dims) {
+ const int64_t ic = i0/2;
+
+ const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ } else {
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+}
+
+template<typename T>
+kernel void kernel_rope_neox(
+ device const void * src0,
+ device const int32_t * src1,
+ device const float * src2,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int & n_past,
+ constant int & n_dims,
+ constant int & n_ctx_orig,
+ constant float & freq_base,
+ constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg[[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int64_t i3 = tgpig[2];
+ const int64_t i2 = tgpig[1];
+ const int64_t i1 = tgpig[0];
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+ device const int32_t * pos = src1;
+
+ const float theta_base = (float) pos[i2];
+ const float inv_ndims = -1.f/n_dims;
+
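+ // theta for dimension pair i0 is theta_base * freq_base^(-i0/n_dims); instead of calling pow()
+ // on every iteration, start at i0 = 2*tiitg and multiply by a constant factor per step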
+ float theta = theta_base * pow(freq_base, 2*tiitg*inv_ndims);
+ const float theta_multiplier = pow(freq_base, 2*tptg.x*inv_ndims);
+
+ float cos_theta;
+ float sin_theta;
+
+ int64_t i0 = 2*tiitg;
+ for ( ; i0 < n_dims; i0 += 2*tptg.x) {
+ const int64_t ic = i0/2;
+
+ const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+ rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+
+ theta *= theta_multiplier;
+ }
+ for ( ; i0 < ne0; i0 += 2*tptg.x) {
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+
+ // Original version, kept for reference; it re-evaluated pow() on every loop iteration:
+ //for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+ // if (i0 < n_dims) {
+ // const int64_t ic = i0/2;
+
+ // // Who thought that having a pow() evaluation in a loop is a good idea?
+ // //const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+ // const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+ // rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ // const float x0 = src[0];
+ // const float x1 = src[n_dims/2];
+
+ // dst_data[0] = x0*cos_theta - x1*sin_theta;
+ // dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+
+ // theta *= theta_multiplier;
+ // } else {
+ // device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ // device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ // dst_data[0] = src[0];
+ // dst_data[1] = src[1];
+ // }
+ //}
+}
+
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
+
+typedef void (im2col_t)(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+ device const float * x,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
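+ // index mapping, assuming the host launches the threadgroup grid over (channels, OH, OW) with
+ // (batch, KH, KW) threads per group: each thread writes one element of the [N*OH*OW, CHW]
+ // destination, reading the input pixel selected by stride/dilation/padding (or 0 when out of bounds)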
+ const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+ const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+ const int32_t offset_dst =
+ (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+ (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+ device T * pdst = (device T *) (dst);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ pdst[offset_dst] = 0.0f;
+ } else {
+ const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
+kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & sf0,
+ constant float & sf1,
+ constant float & sf2,
+ constant float & sf3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
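+ // nearest-neighbour scaling: each destination index i maps back to source index i/sf (truncated)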
+ const int64_t i03 = i3/sf3;
+ const int64_t i02 = i2/sf2;
+ const int64_t i01 = i1/sf1;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int64_t i00 = i0/sf0;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_ptr[0] = src0_ptr[0];
+ }
+}
+
+kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+}
+
+kernel void kernel_arange_f32(
+ device char * dst,
+ constant int64_t & ne0,
+ constant float & start,
+ constant float & step,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ device float * dst_ptr = (device float *) dst;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = start + step * i0;
+ }
+}
+
+kernel void kernel_timestep_embedding_f32(
+ device const char * src0,
+ device char * dst,
+ constant uint64_t & nb1,
+ constant int & dim,
+ constant int & max_period,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ int i = tgpig.x;
+ device float * embed_data = (device float *)(dst + i*nb1);
+
+ int half_ = dim / 2;
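+ // sinusoidal embedding: for frequency index j, write cos(t*f_j) and sin(t*f_j) with
+ // f_j = 1 / max_period^(j/half_), where t is the timestep value of this row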
+ for (int j = tpitg.x; j < half_; j += ntg.x) {
+ float timestep = ((device float *)src0)[i];
+ float freq = (float)exp(-log((float)max_period) * j / half_);
+ float arg = timestep * freq;
+ embed_data[j ] = cos(arg);
+ embed_data[j + half_] = sin(arg);
+ }
+
+ if (dim % 2 != 0 && tpitg.x == 0) {
+ embed_data[dim] = 0.f;
+ }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols_pad) return;
+
+ device const float * x_row = x + row * ncols;
+ threadgroup int32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
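+ // bitonic sort over ncols_pad entries (expected to be a power of two); padding indices (>= ncols)
+ // always compare as coming after valid entries, so they collect at the tail and are dropped below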
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+typedef void (flash_attn_ext_f16_t)(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant uint64_t & nb21,
+ constant uint64_t & nb22,
+ constant uint64_t & nb23,
+ constant uint64_t & nb31,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup half * shared,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant uint64_t & nb21,
+ constant uint64_t & nb22,
+ constant uint64_t & nb23,
+ constant uint64_t & nb31,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0]*Q;
+
+ const short D4 = D/4;
+ const short D8 = D/8;
+ //const short Q8 = Q/8;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ const short TF = T/2; // shared memory size per query in (float)
+ const short T4 = T/4; // shared memory size per query in (half4)
+
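+ // shared memory layout per query row (stride T in half units): the first D halves hold the query
+ // itself (sq/sq4); the remaining 2*nsg*SH halves are viewed as floats (ss, stride TF = T/2) and hold,
+ // per simdgroup, C attention scores followed by a Q-wide diagonal used for rescaling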
+ threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ simdgroup_half8x8 lo[D8];
+
+ // load heads from Q to shared memory
+ for (short j = sgitg; j < Q; j += nsg) {
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 + j < ne01) {
+ sq4[j*T4 + i] = (half4) q4[i];
+ } else {
+ sq4[j*T4 + i] = 0.0h;
+ }
+ }
+ }
+
+ // zero out lo
+ for (short i = 0; i < D8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ }
+
+ // zero out shared memory SH
+ for (short j = 0; j < Q; ++j) {
+ for (short i = tiisg; i < SH; i += NW) {
+ ss[j*TF + i] = 0.0f;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ float S[Q] = { [0 ... Q-1] = 0.0h };
+ float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
+
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
+
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
+ // k indices
+ const short ik2 = iq2/rk2;
+ const short ik3 = iq3/rk3;
+
+ // v indices
+ const short iv2 = iq2/rv2;
+ const short iv3 = iq3/rv3;
+
+ // load the queries from shared memory into local memory
+ simdgroup_half8x8 mq[D8];
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load(mq[i], sq + i*8, T);
+ }
+
+ // pointer to the mask
+ device const half * mp = (device const half *) (mask + iq1*nb31);
+
+ float slope = 1.0f;
+
+ // ALiBi
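+ // slope = m0^(h + 1) for heads h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1)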
+ if (max_bias > 0.0f) {
+ const uint32_t h = iq2;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
+
+ // Q*K^T
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+ }
+
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
+ if (mask != q) {
+ // mqk = mqk*scale + mask*slope
+ ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+ ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
+ } else {
+ // mqk = mqk*scale
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+ }
+ }
+ }
+
+ // used to detect blocks full of -INF
+ float smax = -INFINITY;
+
+ // online softmax
+ {
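+ // streaming softmax update per query row: M is raised to max(M, s); the running sum S and
+ // (just below, via the diag(ms) multiply) the accumulated output are rescaled by exp(M_old - M_new)
+ // before the new exp(s - M_new) terms are added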
+ float ms[Q];
+
+ for (short j = 0; j < Q; ++j) {
+ const short p = tiisg;
+
+ const float m = M[j];
+ const float s = ss[j*TF + p];
+
+ smax = simd_max(max(smax, s));
+ M[j] = simd_max(max(M[j], s));
+
+ ms[j] = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms[j] + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[j*TF + p] = vs;
+ }
+
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg < Q) {
+ ss[tiisg*TF + C + tiisg] = ms[tiisg];
+ }
+ }
+
+ // skip -INF blocks
+ if (smax == -INFINITY) {
+ continue;
+ }
+
+ // O = diag(ms)*O
+ {
+ simdgroup_float8x8 mm;
+ simdgroup_load(mm, ss + C, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_multiply(lo[i], mm, lo[i]);
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_half8x8 mk;
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+ simdgroup_float8x8 mv;
+ simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+ simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+ }
+ }
+ }
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ for (short j = 0; j < Q; ++j) {
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S[j];
+ ss[j*TF + 1] = M[j];
+ }
+ }
+ }
+
+ // reduce the warps sequentially
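+ // merging two partial (S, M, O) states uses the same rescaling trick:
+ // M = max(M0, M1), S = S0*exp(M0 - M) + S1*exp(M1 - M), O = exp(M0 - M)*O0 + exp(M1 - M)*O1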
+ for (short sg = 1; sg < nsg; ++sg) {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // each simdgroup stores its output to shared memory, reusing sq
+ if (sgitg == sg) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // the first simdgroup accumulates the results from the other simdgroups
+ if (sgitg == 0) {
+ for (short j = 0; j < Q; ++j) {
+ const float S0 = ss[j*TF + 0];
+ const float S1 = ss[j*TF + sg*SH + 0];
+
+ const float M0 = ss[j*TF + 1];
+ const float M1 = ss[j*TF + sg*SH + 1];
+
+ M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[j*TF + 0] = S;
+ ss[j*TF + 1] = M;
+
+ ss[j*TF + C + j ] = ms0;
+ ss[j*TF + C + j + sg*SH] = ms1;
+ }
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ {
+ simdgroup_half8x8 t;
+ simdgroup_float8x8 ms0;
+ simdgroup_float8x8 ms1;
+
+ simdgroup_load(ms0, ss + C, TF, 0, false);
+ simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_load (t, sq + i*8, T, 0, false);
+ simdgroup_multiply(t, ms1, t);
+
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ }
+ }
+ }
+ }
+
+ // store result to shared memory (reuse sq)
+ if (sgitg == 0) {
+ for (short i = 0; i < D8; ++i) {
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ }
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+ const float S = ss[j*TF + 0];
+
+ for (short i = tiisg; i < D4; i += NW) {
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+ }
+ }
+ }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device float * dst,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant uint64_t & nb21,
+ constant uint64_t & nb22,
+ constant uint64_t & nb23,
+ constant uint64_t & nb31,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup half * shared [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const short nsg = ntg.y; // number of simdgroups
+
+ const short iq3 = tgpig[2];
+ const short iq2 = tgpig[1];
+ const short iq1 = tgpig[0];
+
+ const short D4 = D/4;
+ const short NW = N_SIMDWIDTH;
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const uint32_t h = iq2;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
+
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+
+ // store the result for the single query in local memory, spread across the simdgroup (the O vector from the paper)
+ half4 lo[D4/NW];
+
+ // load heads from Q to shared memory
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+ for (short i = tiisg; i < D4; i += NW) {
+ if (iq1 < ne01) {
+ sq4[i] = (half4) q4[i];
+ } else {
+ sq4[i] = 0.0h;
+ }
+ }
+
+ // zero out lo
+ for (short i = tiisg; i < D4; i += NW) {
+ lo[i/NW] = 0.0h;
+ }
+
+ // zero out shared memory SH
+ for (short i = tiisg; i < SH/4; i += NW) {
+ ss4[i] = 0.0h;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ {
+ float S = { 0.0h };
+ float M = { -FLT_MAX/2 };
+
+ // assume K and V are same shape
+ const short ne22 = ne12;
+ const short ne23 = ne13;
+
+ // broadcast
+ const short rk2 = ne02/ne12;
+ const short rk3 = ne03/ne13;
+
+ const short rv2 = ne02/ne22;
+ const short rv3 = ne03/ne23;
+
+ // k indices
+ const short ik2 = iq2 / rk2;
+ const short ik3 = iq3 / rk3;
+
+ // v indices
+ const short iv2 = iq2 / rv2;
+ const short iv3 = iq3 / rv3;
+
+ // load the queries from shared memory into local memory
+ half4 mq[D4];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ mq[i] = sq4[i];
+ }
+
+ // pointer to the mask
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+ // loop over the KV cache
+ // each simdgroup handles blocks of Q rows and C columns
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ const int ic = ic0 + C*sgitg;
+ if (ic >= ne11) {
+ break;
+ }
+
+ // Q*K^T
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ float4 mqk = { 0.0h };
+
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ half4x4 mk;
+ mk[0] = pk4[i + 0*(nb11/8)];
+ mk[1] = pk4[i + 1*(nb11/8)];
+ mk[2] = pk4[i + 2*(nb11/8)];
+ mk[3] = pk4[i + 3*(nb11/8)];
+
+ mqk += (float4) (mq[i] * mk);
+ }
+
+ // reduce the results from the threads in the simdgroup
+ mqk += simd_shuffle_down(mqk, 16);
+ mqk += simd_shuffle_down(mqk, 8);
+ mqk += simd_shuffle_down(mqk, 4);
+ mqk += simd_shuffle_down(mqk, 2);
+ mqk += simd_shuffle_down(mqk, 1);
+
+ // mqk = mqk*scale + mask*slope
+ if (tiisg == 0) {
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+
+ ss4[cc] = mqk;
+ }
+ }
+ }
+
+ // online softmax
+ {
+ const short p = tiisg;
+
+ const float m = M;
+ const float s = ss[p];
+
+ M = simd_max(max(M, s));
+
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+ // the P matrix from the paper (Q rows, C columns)
+ ss[p] = vs;
+
+ // O = diag(ms)*O
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+ lo[i/NW] *= ms;
+ }
+ }
+
+ // O = O + (Q*K^T)*V
+ {
+#pragma unroll
+ for (short cc = 0; cc < C/4; ++cc) {
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+ for (short ii = 0; ii < D4; ii += NW) {
+ const short i = ii + tiisg;
+
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ }
+ }
+ }
+
+ }
+
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+ }
+
+ // store results to shared memory
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = lo[ii/NW];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // parallel reduce
+ for (short r = nsg/2; r > 0; r >>= 1) {
+ if (sgitg < r) {
+ const float S0 = ss[ 0];
+ const float S1 = ss[r*SH + 0];
+
+ const float M0 = ss[ 1];
+ const float M1 = ss[r*SH + 1];
+
+ const float M = max(M0, M1);
+
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
+
+ const float S = S0*ms0 + S1*ms1;
+
+ if (tiisg == 0) {
+ ss[0] = S;
+ ss[1] = M;
+ }
+
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ device float4 * dst4 = (device float4 *) dst;
+
+ // final rescale with 1/S and store to global memory
+ if (sgitg == 0) {
+ const float S = ss[0];
+
+ for (short ii = 0; ii < D4; ii += NW) {
+ short i = ii + tiisg;
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+ }
+ }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+ device const void * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = (T1) src[0];
+ }
+}
+
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
+
+kernel void kernel_cpy_f32_q8_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+
+ device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
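+ // Q8_0: one scale d per block of QK8_0 values, chosen as amax/127 so the quants fit in int8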
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK8_0].d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst_data[i00/QK8_0].qs[j] = round(x0);
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q4_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+
+ device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
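+ // Q4_0: 4-bit quants with an implicit offset of 8; d = max/-8 maps the value with the largest
+ // magnitude to index 0, i.e. to -8 once the offset is subtracted at dequantization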
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_0].d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ dst_data[i00/QK4_0].qs[j] = xi0;
+ dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+
+ device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; j++) {
+ const float v = src[j];
+ if (min > v) min = v;
+ if (max < v) max = v;
+ }
+
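+ // Q4_1: asymmetric 4-bit quantization; values are reconstructed as q*d + m with m = min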
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK4_1].d = d;
+ dst_data[i00/QK4_1].m = min;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ dst_data[i00/QK4_1].qs[j] = xi0;
+ dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_0(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK5_0; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_0].d = d;
+
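+ // the low 4 bits of each quant go into qs (two per byte); the 5th (high) bits of all QK5_0 quants
+ // are packed into the 32-bit qh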
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
+ }
+ }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float max = src[0];
+ float min = src[0];
+
+ for (int j = 1; j < QK5_1; j++) {
+ const float v = src[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst_data[i00/QK5_1].d = d;
+ dst_data[i00/QK5_1].m = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (src[0 + j] - min)*id;
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+ for (int j = 0; j < 4; ++j) {
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
+ }
+ }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+kernel void kernel_cpy_f32_iq4_nl(
+ device const float * src0,
+ device void * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < QK4_NL; j++) {
+ const float v = src[j];
+ if (amax < fabs(v)) {
+ amax = fabs(v);
+ max = v;
+ }
+ }
+
+ const float d = max / kvalues_iq4nl_f[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
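+ // refine the initial scale guess by a weighted least-squares fit over the chosen codebook values:
+ // d = sum(w*v*x) / sum(w*v*v) with weights w = x^2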
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = src[0 + j]*id;
+ const float x1 = src[QK4_NL/2 + j]*id;
+
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+ const float v0 = kvalues_iq4nl_f[xi0];
+ const float v1 = kvalues_iq4nl_f[xi1];
+ const float w0 = src[0 + j]*src[0 + j];
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+
+ }
+
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+
+ }
+}
+
+kernel void kernel_concat(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & dim,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
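+ // o[dim] is the extent of src0 along the concat dimension: indices inside src0's range read from
+ // src0, the rest read from src1 shifted back by that extent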
+ device const float * x;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ } else {
+ x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ }
+
+ device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ *y = *x;
+ }
+}
+
+void kernel_mul_mv_q2_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_q2_K) * nb;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+ const int is = (8*ir)/16;// 0 or 1
+
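+ // each simdgroup produces N_DST output rows; the 32 lanes split into 4 groups of 8 (ix), each group
+ // walking every 4th super-block, and within a super-block every lane accumulates 32 of the 256 values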
+ device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+ }
+
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+ }
+ float dall = dh[0];
+ float dmin = dh[1] * 1.f/16.f;
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+ qs += step/2;
+ sc += step;
+ dh += step/2;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int64_t im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+
+ //const uint16_t kmask1 = 0x3030;
+ //const uint16_t kmask2 = 0x0f0f;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int ip = tid/4; // 0 or 1
+ const int il = 2*((tid%4)/2); // 0 or 2
+ const int ir = tid%2;
+ const int n = 8;
+ const int l0 = n*ir;
+
+ // One would think that the Metal compiler would figure out that ip and il can only have
+ // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+ // with these two tables.
+ //
+ // Possible masks for the high bit
+ const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
+ {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
+ {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
+ {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+ // Possible masks for the low 2 bits
+ const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+ const ushort4 hm = mm[2*ip + il/2];
+
+ const int shift = 2*il;
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
+
+ const uint16_t s_shift1 = 4*ip;
+ const uint16_t s_shift2 = s_shift1 + il;
+
+ const int q_offset = 32*ip + l0;
+ const int y_offset = 128*ip + 32*il + l0;
+
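+ // step advances q/h/a/dh (all 16-bit views) from one row of src0 to the next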
+ const int step = sizeof(block_q3_K) * nb / 2;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ uint32_t scales32, aux32;
+ thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+ thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+ float sumf1[2] = {0.f};
+ float sumf2[2] = {0.f};
+ for (int i = ix; i < nb; i += 4) {
+
+ for (int l = 0; l < 8; ++l) {
+ yl[l+ 0] = y1[l+ 0];
+ yl[l+ 8] = y1[l+16];
+ yl[l+16] = y1[l+32];
+ yl[l+24] = y1[l+48];
+ }
+
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+ device const half * dh = &x[i].d;
+
+ for (int row = 0; row < 2; ++row) {
+
+ const float d_all = (float)dh[0];
+
+ scales16[0] = a[4];
+ scales16[1] = a[5];
+ aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+ scales16[0] = a[il+0];
+ scales16[1] = a[il+1];
+ scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+ float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2];
+ s1 += yl[l+0] * (qs & qm[il/2][0]);
+ s2 += yl[l+1] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+ s4 += yl[l+16] * (qs & qm[il/2][2]);
+ s5 += yl[l+17] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+ }
+ float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[0] - 32);
+ sumf2[row] += d2 * (scales[2] - 32);
+
+ s1 = s2 = s3 = s4 = s5 = s6 = 0;
+ for (int l = 0; l < n; l += 2) {
+ const int32_t qs = q[l/2+8];
+ s1 += yl[l+8] * (qs & qm[il/2][0]);
+ s2 += yl[l+9] * (qs & qm[il/2][1]);
+ s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+ s4 += yl[l+24] * (qs & qm[il/2][2]);
+ s5 += yl[l+25] * (qs & qm[il/2][3]);
+ s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+ }
+ d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+ d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+ sumf1[row] += d1 * (scales[1] - 32);
+ sumf2[row] += d2 * (scales[3] - 32);
+
+ q += step;
+ h += step;
+ a += step;
+ dh += step;
+
+ }
+
+ y1 += 4 * QK_K;
+
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+ sumf1[row] = simd_sum(sumf);
+ }
+ if (tiisg == 0) {
+ for (int row = 0; row < 2; ++row) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q4_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int ix = tiisg/8; // 0...3
+ const int it = tiisg%8; // 0...7
+ const int iq = it/4; // 0 or 1
+ const int ir = it%4; // 0...3
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float yh[16];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int step = sizeof(block_q4_K) * nb / 2;
+
+ device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ for (int ib = ix; ib < nb; ib += 4) {
+
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+ }
+
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+ device const half * dh = &x[ib].d;
+
+ for (int row = 0; row < N_DST; row++) {
+
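+            // unpack the packed 6-bit (scale, min) pairs for this super-block: sc8[0..1] and
+            // sc8[4..5] become the scales, sc8[2..3] and sc8[6..7] the mins used with sumy below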
+ sc16[0] = sc[0] & kmask1;
+ sc16[1] = sc[2] & kmask1;
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+ device const uint16_t * q2 = q1 + 32;
+
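+            // each uint16 load packs two quant bytes: the high-byte terms (0x0F00/0xF000) are
+            // rescaled by 1/256 and the high-nibble terms (0x00F0/0xF000) by 1/16 further below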
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+ for (int i = 0; i < 8; i += 2) {
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+ }
+
+ float dall = dh[0];
+ float dmin = dh[1];
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ sc += step;
+ dh += step;
+ }
+
+ y4 += 4 * QK_K;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float sumf[2]={0.f};
+
+ const int step = sizeof(block_q5_K) * nb;
+
+ float yl[16], yh[16];
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid = tiisg/4;
+ const int ix = tiisg%4;
+ const int iq = tid/4;
+ const int ir = tid%4;
+ const int n = 8;
+
+ const int l0 = n*ir;
+ const int q_offset = 32*iq + l0;
+ const int y_offset = 64*iq + l0;
+
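+    // hm1..hm4 select the 5th bit of each 4-bit quant from qh; the 16.f*acc2 terms in the
+    // accumulation below restore the weight of that bit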
+ const uint8_t hm1 = 1u << (2*iq);
+ const uint8_t hm2 = hm1 << 1;
+ const uint8_t hm3 = hm1 << 4;
+ const uint8_t hm4 = hm2 << 4;
+
+ uint16_t sc16[4];
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+ device const float * y1 = yy + ix*QK_K + y_offset;
+
+ for (int i = ix; i < nb; i += 4) {
+
+ device const uint8_t * q1 = x[i].qs + q_offset;
+ device const uint8_t * qh = x[i].qh + l0;
+ device const half * dh = &x[i].d;
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+ device const float * y2 = y1 + 128;
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < 8; ++l) {
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+ }
+
+ for (int row = 0; row < 2; ++row) {
+
+ device const uint8_t * q2 = q1 + 64;
+
+ sc16[0] = a[0] & kmask1;
+ sc16[1] = a[2] & kmask1;
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+ float4 acc1 = {0.f};
+ float4 acc2 = {0.f};
+ for (int l = 0; l < n; ++l) {
+ uint8_t h = qh[l];
+ acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+ acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+ acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+ acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+ acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+ acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+ acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+ acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+ }
+ const float dall = dh[0];
+ const float dmin = dh[1];
+ sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+ sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+ sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+ sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+ q1 += step;
+ qh += step;
+ dh += step/2;
+ a += step/2;
+
+ }
+
+ y1 += 4 * QK_K;
+
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
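+    // kmask1..kmask4 extract the 2-bit high parts of the 6-bit quants from qh; they are
+    // combined with the 4-bit low parts from ql and re-centered by subtracting 32 below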
+ const uint8_t kmask1 = 0x03;
+ const uint8_t kmask2 = 0x0C;
+ const uint8_t kmask3 = 0x30;
+ const uint8_t kmask4 = 0xC0;
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int row = 2 * r0 + sgitg;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float sumf = 0;
+
+ const int tid = tiisg/2;
+ const int ix = tiisg%2;
+ const int ip = tid/8; // 0 or 1
+ const int il = tid%8;
+ const int n = 4;
+ const int l0 = n*il;
+ const int is = 8*ip + l0/16;
+
+ const int y_offset = 128*ip + l0;
+ const int q_offset_l = 64*ip + l0;
+ const int q_offset_h = 32*ip + l0;
+
+ for (int i = ix; i < nb; i += 2) {
+
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
+ device const uint8_t * q2 = q1 + 32;
+ device const uint8_t * qh = x[i].qh + q_offset_h;
+ device const int8_t * sc = x[i].scales + is;
+
+ device const float * y = yy + i * QK_K + y_offset;
+
+ const float dall = x[i].d;
+
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+ }
+
+ sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+ }
+
+ const float tot = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+ }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+// ======================= "True" 2-bit
+
+void kernel_mul_mv_iq2_xxs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
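+    // stage the iq2xxs codebook and the sign table in threadgroup memory; the per-weight
+    // lookups in the inner loop read from this shared copy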
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xxs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
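+            // the first four bytes of q2 index the grid entries staged above; q2[2..3] pack
+            // four 7-bit sign indices plus a 4-bit scale in the top bits of aux32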
+ const float db = dh[0];
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * sum;
+
+ dh += nb*sizeof(block_iq2_xxs)/2;
+ q2 += nb*sizeof(block_iq2_xxs)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq2_xs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
+ {
+ int nval = 8;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const uint8_t * sc = xr->scales + ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const uint8_t ls1 = sc[0] & 0xf;
+ const uint8_t ls2 = sc[0] >> 4;
+ const float d1 = db * (0.5f + ls1);
+ const float d2 = db * (0.5f + ls2);
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ for (int l = 2; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d1 * sum1 + d2 * sum2;
+
+ dh += nb*sizeof(block_iq2_xs)/2;
+ q2 += nb*sizeof(block_iq2_xs)/2;
+ sc += nb*sizeof(block_iq2_xs);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq3_xxs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq3_xxs * xr = x + ibl;
+ device const uint8_t * q3 = xr->qs + 8 * ib;
+ device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float2 sum = {0};
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * (sum[0] + sum[1]);
+
+ dh += nb*sizeof(block_iq3_xxs)/2;
+ q3 += nb*sizeof(block_iq3_xxs);
+ gas += nb*sizeof(block_iq3_xxs)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq3_s_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+ {
+ int nval = 8;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq3_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 8 * ib;
+ device const uint8_t * qh = xr->qh + ib;
+ device const uint8_t * sc = xr->scales + (ib/2);
+ device const uint8_t * signs = xr->signs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
+
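+            // each qh bit supplies bit 8 of the corresponding grid index, selecting the
+            // upper or lower half of the 512-entry iq3s codebook staged above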
+ float2 sum = {0};
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+ }
+ }
+ sumf[row] += d * (sum[0] + sum[1]);
+
+ dh += nb*sizeof(block_iq3_s)/2;
+ qs += nb*sizeof(block_iq3_s);
+ qh += nb*sizeof(block_iq3_s);
+ sc += nb*sizeof(block_iq3_s);
+ signs += nb*sizeof(block_iq3_s);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq2_s_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ //{
+ // int nval = 32;
+ // int pos = (32*sgitg + tiisg)*nval;
+ // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
+ //}
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint8_t * qh = xr->qh + ib;
+ device const uint8_t * sc = xr->scales + ib;
+ device const uint8_t * signs = qs + QK_K/8;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const float d1 = db * (0.5f + (sc[0] & 0xf));
+ const float d2 = db * (0.5f + (sc[0] >> 4));
+
+ float2 sum = {0};
+ for (int l = 0; l < 2; ++l) {
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+ for (int j = 0; j < 8; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
+ sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
+ }
+ }
+ sumf[row] += d1 * sum[0] + d2 * sum[1];
+
+ dh += nb*sizeof(block_iq2_s)/2;
+ qs += nb*sizeof(block_iq2_s);
+ qh += nb*sizeof(block_iq2_s);
+ sc += nb*sizeof(block_iq2_s);
+ signs += nb*sizeof(block_iq2_s);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_s_f32")]]
+kernel void kernel_mul_mv_iq2_s_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_iq1_s_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float sumy = 0;
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ sumy += yl[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq1_s * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint16_t * qh = xr->qh + ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
+
+ float sum = 0;
+ for (int j = 0; j < 4; ++j) {
+ sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+ }
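+            // qh packs the 3-bit sub-block scale (bits 12..14) and the delta sign (bit 15);
+            // sumy*(-1 ± IQ1S_DELTA) applies the common per-block shift to all 32 weights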
+ sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
+
+ dh += nb*sizeof(block_iq1_s)/2;
+ qs += nb*sizeof(block_iq1_s);
+ qh += nb*sizeof(block_iq1_s)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+void kernel_mul_mv_iq1_m_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ iq1m_scale_t scale;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ float4 sumy = {0.f};
+ for (int i = 0; i < 8; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq1_m * xr = x + ibl;
+ device const uint8_t * qs = xr->qs + 4 * ib;
+ device const uint8_t * qh = xr->qh + 2 * ib;
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+ for (int row = 0; row < N_DST; row++) {
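+            // the fp16 block scale is stored 4 bits at a time in the top nibble of each of
+            // the four packed scale words; reassemble it before applying the 3-bit sub-scales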
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+ float2 sum = {0.f};
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+ }
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+ sc += nb*sizeof(block_iq1_m)/2;
+ qs += nb*sizeof(block_iq1_m);
+ qh += nb*sizeof(block_iq1_m);
+ }
+
+ y4 += 32 * 32;
+ }
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
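+// decode a custom 8-bit float: the top 3 bits are the exponent (re-biased into the IEEE-754
+// exponent field) and the low 5 bits are placed at the top of the mantissa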
+static inline float iq1bn_fp8_to_float(uint8_t fp8) {
+ typedef union { float f; uint32_t i; } scale_t;
+ scale_t s; s.i = (((fp8 >> 5) + 116) << 23) | ((fp8 & 0x1f) << 18);
+ return s.f;
+}
+
+//static constant int8_t iq1bn_values[256*5] = {
+// -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+// -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+// -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+// -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+// -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+// -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+// 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+// -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+// 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+// 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+// -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+// 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+// -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+// 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+// -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+// 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+// 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+// 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+// -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+// 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+// 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+// 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+// 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+// 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+// 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+// 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+// 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+// 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+// 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+// -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+// 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+// 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+// 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+// -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+// 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+// 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+// -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+// 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+// 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+// 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+//};
+
+void kernel_mul_mv_iq1_bn_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_IQ1BN;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq1_bn * x = (device const block_iq1_bn *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[N_DST]={0.f};
+
+ const int nb32 = nb * (QK_IQ1BN / 32);
+
+ const int ix = tiisg/2;
+ const int ir = tiisg%2;
+
+ device const float * y4 = (device const float *)y + 32 * ix + 16 * ir;
+
+ uint32_t aux32[2];
+ thread const uint8_t * aux8 = (thread const uint8_t *)aux32;
+
+ const float values[3] = {-1.f, 0.f, 1.f};
+
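+    // each ql byte packs five ternary digits; the multiply-and-shift below (k_mult[j]*q,
+    // then 3*v >> 8) recovers digit j as 0, 1 or 2, which indexes values[] = {-1, 0, 1}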
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
+
+ for (int j = 0; j < 16; ++j) yl[j] = y4[j];
+
+ const int ibl = ib32 / (QK_IQ1BN / 32);
+ const int ib = ib32 % (QK_IQ1BN / 32);
+ const int i16 = 2*ib + ir;
+
+ device const block_iq1_bn * xr = x + ibl;
+ device const uint8_t * ql = xr->ql + 3*i16;
+ device const uint8_t * extra = (device const uint8_t *)&xr->extra;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float acc = 0;
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ //constant int8_t * vs = iq1bn_values + 5*ql[k];
+ //for (int j = 0; j < 5; ++j) acc += yl[5*k+j]*vs[j];
+ uint8_t q = ql[k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[i++] * values[v];
+ }
+ }
+ //constant int8_t * vs = iq1bn_values + 5*extra[0];
+ //acc += yl[15] * vs[i16];
+ uint8_t v = k_mult[i16]*extra[0];
+ v = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ acc += yl[15] * values[v];
+
+ sumf[row] += acc;
+
+ extra += nb*sizeof(block_iq1_bn);
+ ql += nb*sizeof(block_iq1_bn);
+ }
+
+ y4 += 32 * 16;
+ }
+
+ for (int row = 0; row < N_DST; row += 2) {
+ half2 r = {(half)sumf[row], (half)sumf[row+1]};
+ r = simd_sum(r);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg];
+ }
+ }
+}
+
+// TODO
+void kernel_mul_mv_iq2_bn_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_value,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ const int nb = ne00/QK_IQ1BN;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq2_bn * x = (device const block_iq2_bn *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[16];
+ float sumf[N_DST]={0.f};
+
+ const int nb32 = nb * (QK_IQ1BN / 32);
+
+ const int ix = tiisg/4; // 0...7
+ const int ir = tiisg%4; // 0...3
+
+ device const float * y4 = y + 64 * ix + 4 * ir;
+
+ for (int ib = ix; ib < nb; ib += 8) {
+
+ float sumy = 0.f;
+ for (int i = 0; i < 4; ++i) {
+ yl[i+ 0] = y4[i+ 0]; sumy += yl[i+ 0];
+ yl[i+ 4] = y4[i+16]; sumy += yl[i+ 4];
+ yl[i+ 8] = y4[i+32]; sumy += yl[i+ 8];
+ yl[i+12] = y4[i+48]; sumy += yl[i+12];
+ }
+
+ device const uint8_t * qs = x[ib].qs + 4*ir;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ float4 acc = {0.f};
+ for (int j = 0; j < 4; ++j) {
+ acc[0] += yl[j+ 0] * (qs[j] & 0x03);
+ acc[1] += yl[j+ 4] * (qs[j] & 0x0c);
+ acc[2] += yl[j+ 8] * (qs[j] & 0x30);
+ acc[3] += yl[j+12] * (qs[j] & 0xc0);
+ }
+
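+            // the 2-bit fields sit at bit offsets 0/2/4/6, hence the 1/4, 1/16, 1/64 factors;
+            // subtracting sumy maps the stored values {0,1,2} onto the ternary weights {-1,0,1}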
+            sumf[row] += acc[0] + 0.25f*acc[1] + 0.0625f*acc[2] + 0.015625f*acc[3] - sumy;
+
+ qs += nb*sizeof(block_iq2_bn);
+ }
+
+ y4 += 64 * 8;
+ }
+
+ for (int row = 0; row < N_DST; row += 2) {
+ half2 r = {(half)sumf[row], (half)sumf[row+1]};
+ r = simd_sum(r);
+ if (tiisg < 2) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row + tiisg] = r[tiisg];
+ }
+ }
+}
+
+void kernel_mul_mv_iq4_nl_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+ const int nb = ne00/QK4_NL;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/2; // 0...15
+ const int it = tiisg%2; // 0 or 1
+
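+    // copy the 16-entry non-linear codebook into threadgroup memory; every 4-bit index
+    // extracted from q4 below is mapped to its float value through this table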
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK4_NL + it * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+
+ device const block_iq4_nl & xb = x[row*nb + ib];
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = q4[0] | (q4[1] << 16);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = q4[2] | (q4[3] << 16);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ }
+
+ yb += 16 * QK4_NL;
+ }
+
+ for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+void kernel_mul_mv_iq4_xs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+ const int first_row = (r0 * 2 + sgitg) * 2;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+ device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ const int ix = tiisg/16; // 0 or 1
+ const int it = tiisg%16; // 0...15
+ const int ib = it/2;
+ const int il = it%2;
+
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[2]={0.f}, all_sum;
+
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+ uint32_t aux32[2];
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+ float4 qf1, qf2;
+
+ for (int ibl = ix; ibl < nb; ibl += 2) {
+
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+ for (int row = 0; row < 2; ++row) {
+
+ device const block_iq4_xs & xb = x[row*nb + ibl];
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+ float4 acc1 = {0.f}, acc2 = {0.f};
+
+ aux32[0] = q4[0] & 0x0f0f0f0f;
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[0] * qf1;
+ acc2 += yl[1] * qf2;
+
+ aux32[0] = q4[1] & 0x0f0f0f0f;
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ acc1 += yl[2] * qf1;
+ acc2 += yl[3] * qf2;
+
+ acc1 += acc2;
+
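+            // reassemble the 6-bit block scale from the packed low nibble (scales_l) and the
+            // 2-bit high part (scales_h), then re-center it by subtracting 32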
+ const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+ sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+ }
+
+ yb += 2 * QK_K;
+ }
+
+ for (int row = 0; row < 2; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_bn_f32")]]
+kernel void kernel_mul_mv_iq1_bn_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq1_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq2_bn_f32")]]
+kernel void kernel_mul_mv_iq2_bn_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_bn_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+//============================= templates and their specializations =============================
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
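+    // each uint16 packs two quant bytes; d2 = d1/256 rescales the high byte and
+    // md = -8*d re-centers the unsigned 4-bit values around zero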
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+        // extract the 5th bit for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4 bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+}
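+
+// q5_0 keeps the 5th bit of every weight packed into the 32-bit qh field. The gh_mv/gh_bk shifts
+// above only move the relevant qh bit into position 4 (0x10) so it can be OR-ed onto the 4-bit
+// value taken from qs, giving a 5-bit quant x in 0..31; the weight is then d*x - 16*d, i.e.
+// d*(x - 16). dequantize_q5_1 below is identical except that the offset is the stored minimum m
+// instead of -16*d.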
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+        // extract the 5th bit for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4 bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
+
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const float d = xb->d;
+ const float min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ float dl, ml;
+ uint8_t sc = xb->scales[il];
+
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
+
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ const float ml = 4.f * dl;
+
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
+
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+ }
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
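+
+// get_scale_min_k4_just2 undoes the 6-bit packing of the q4_K/q5_K super-block scales: the
+// 12-byte scales[] array holds 8 (scale, min) pairs of 6 bits each. Pairs 0-3 use the low 6 bits
+// of bytes 0-3 (scales) and bytes 4-7 (mins); pairs 4-7 keep their low 4 bits in the nibbles of
+// bytes 8-11 and borrow their top 2 bits from bits 6-7 of bytes 0-3 and 4-7 respectively, which
+// is what the j >= 4 branch above reassembles.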
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+ device const uchar * q = xb->qs;
+
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
+
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ float sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
+
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ const float ml = d_all * sc * 32.f;
+ const float dl = d_all * sc * coef;
+ for (int i = 0; i < 16; ++i) {
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * q3 = xb->qs + 8*ib32;
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 8*ib32;
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+ }
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * signs = qs + QK_K/8;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+ reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ const float d = xb->d;
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint16_t * qh = xb->qh;
+ const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+ const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+ const uint16_t h = qh[ib32] >> 6*il;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = scale.f16;
+
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
+ }
+}
+
+
+template <typename type4x4>
+void dequantize_iq1_bn(device const block_iq1_bn * xb, short il, thread type4x4 & reg) {
+ // il is in 0...3
+
+ constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+
+ int i = 0;
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = xb->ql[3*il + k];
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = k_mult[j]*q;
+ int8_t vs = 3*v >> 8;
+ //int8_t vs = (v + (v >> 1)) >> 7;
+ reg[i/4][i%4] = vs - 1;
+ ++i;
+ }
+ }
+ uint8_t v = k_mult[il]*xb->extra;
+ int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
+ reg[3][3] = vs - 1;
+}
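+
+// iq1_bn stores ternary weights (-1, 0, +1). Each ql byte carries five packed ternary digits,
+// laid out by the encoder so that digit j can be recovered branch-free: multiplying by
+// k_mult[j] (with uint8_t wrap-around) pushes digit j to the top of the byte and 3*v >> 8 reads
+// it back as 0, 1 or 2; subtracting 1 then maps it onto {-1, 0, +1}. The 16th value of the 4x4
+// tile is decoded the same way from the block's shared `extra` byte.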
+
+template <typename type4x4>
+void dequantize_iq2_bn(device const block_iq2_bn * xb, short il, thread type4x4 & reg) {
+ // il is in 0...3
+ constexpr float k_scale[4] = {1.f, 0.25f, 0.0625f, 0.015625f};
+ constexpr uint8_t k_mask[4] = {0x03, 0x0c, 0x30, 0xc0};
+ const float d = k_scale[il];
+ uint8_t mask = k_mask[il];
+
+ for (int j = 0; j < 16; ++j) {
+ reg[j/4][j%4] = d * (xb->qs[j] & mask) - 1;
+ }
+}
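+
+// iq2_bn packs four 2-bit values per qs byte. Rather than shifting the selected field down, the
+// per-il scale (1, 1/4, 1/16, 1/64) folds the shift into the multiply, and the trailing -1
+// recenters the stored 0/1/2 onto the ternary {-1, 0, +1} range.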
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+ const float d = xb->d;
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+ const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+ const float d = (float)xb->d * (ls - 32);
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows_q(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+ float4x4 temp;
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
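+
+// kernel_get_rows_q gathers whole rows of a quantized src0 according to the int32 row indices
+// held in src1 and writes them out dequantized, one float4x4 (16 values) per loop iteration,
+// with the loop split across the threadgroup. Roughly (host-side sketch, names illustrative):
+//
+//   row = ((int32_t *)(src1 + i11*nb11 + i10*nb10))[0];
+//   for (ind = 0; ind < ne00/16; ++ind)
+//       dst_row[ind] = dequantize_func(block(src0, row, i02) + ind/nl, ind % nl);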
+
+template<typename T>
+kernel void kernel_get_rows_f(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ }
+}
+
+kernel void kernel_get_rows_i32(
+ device const void * src0,
+ device const void * src1,
+ device int32_t * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ }
+}
+
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread takes 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread takes 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
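+// Tile geometry used by kernel_mul_mm below: each threadgroup of 128 threads (4 simdgroups)
+// computes a 64x32 block of the result out of 8x8 simdgroup matrices. The threadgroup buffer is
+// laid out as a BLOCK_SIZE_M x BLOCK_SIZE_K tile of the dequantized A operand (64*32 half values
+// = 4096 bytes, which is where the `shared_memory + 4096` offset for sb comes from), followed by
+// a BLOCK_SIZE_K x BLOCK_SIZE_N float tile of B.
+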
+// each block_q contains 16*nl weights
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const uchar * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup T * sa = (threadgroup T *)(shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_T8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ T4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(16)
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ #pragma unroll(4)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ }
+ } else {
+        // this block is smaller than 64x32, so avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ if (sgitg == 0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+// same as kernel_mul_mm above, but src1 and dst are accessed via indices stored in rowids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+ device const uchar * src0,
+ device const uchar * src1,
+ threadgroup ushort2 * rowids,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ int64_t ne1,
+ int64_t ne0ne1,
+ threadgroup uchar * shared_memory,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+ short il = (tiitg % THREAD_PER_ROW);
+
+ ushort offset1 = il/nl;
+
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * id[1]
+ + nb11 * (id[0] % ne11)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ device float * C = dst + (BLOCK_SIZE_M * r0);
+ if (sgitg == 0) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
+ for (int i = 0; i < n_rows; i++) {
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+ device const uchar * src0s,
+ device const uchar * src1,
+ device float * dst,
+ device const uchar * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int32_t i02 = tgpig.z;
+ tgpig.z = 0;
+
+ device const uchar * src0 = src0s + i02*nb02;
+
+ // row indices
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+
+ // TODO: parallelize this loop
+ int64_t _ne1 = 0;
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ if (id == i02) {
+ //if (tiitg == 0) {
+ rowids[_ne1] = ushort2(ii0, ii1);
+ //}
+ _ne1++;
+ }
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ src0,
+ src1,
+ rowids,
+ dst,
+ ne00,
+ ne02,
+ nb01,
+ nb02,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ _ne1,
+ ne0*ne1,
+ shared_memory,
+ tgpig,
+ tiitg,
+ sgitg);
+}
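+
+// kernel_mul_mm_id handles the indirect (mixture-of-experts style) case: tgpig.z selects one
+// expert matrix (i02), the ids tensor says which (ii0, ii1) positions of src1/dst belong to that
+// expert, and the scan above collects those positions into rowids so kernel_mul_mm_id_impl only
+// multiplies the rows that actually use this expert. Every thread currently runs the same scan
+// and stores the same values into rowids (hence the commented-out tiitg == 0 guard).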
+
+#define QK_NL 16
+
+//
+// get rows
+//
+
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_get_rows_iq1_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_bn, 4, dequantize_iq1_bn>;
+template [[host_name("kernel_get_rows_iq2_bn")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_bn, 4, dequantize_iq2_bn>;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_bn, 4, dequantize_iq1_bn>;
+template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_bn, 4, dequantize_iq2_bn>;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_id_iq1_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_bn, 4, dequantize_iq1_bn>;
+template [[host_name("kernel_mul_mm_id_iq2_bn_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_bn, 4, dequantize_iq2_bn>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+
+//
+// matrix-vector multiplication
+//
+
+typedef void (kernel_mul_mv_impl_t)(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg);
+
+typedef void (kernel_mul_mv2_impl_t)(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg);
+
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
+}
+
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
+}
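+
+// The two mmv_fn overloads are thin adapters: they accept the full argument list that
+// kernel_mul_mv_id forwards and drop whatever a particular *_impl does not take, so both the
+// plain matrix-vector implementations (kernel_mul_mv_impl_t) and the ones that need shared
+// memory (kernel_mul_mv2_impl_t) can be wrapped by the same kernel_mul_mv_id template below.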
+
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
+
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
+ device const char * src0s,
+ device const char * src1,
+ device float * dst,
+ device const char * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int iid1 = tgpig.z/nei0;
+ const int idx = tgpig.z%nei0;
+
+ tgpig.z = 0;
+
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
+
+ const int64_t i11 = idx % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = idx;
+ const int64_t i2 = i12;
+
+ device const char * src0_cur = src0s + i02*nb02;
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
+
+ impl_fn(
+ /* src0 */ src0_cur,
+ /* src1 */ src1_cur,
+ /* dst */ dst_cur,
+ /* ne00 */ ne00,
+ /* ne01 */ ne01,
+ /* ne02 */ 1,//ne02,
+ /* nb00 */ nb00,
+ /* nb01 */ nb01,
+ /* nb02 */ nb02,
+ /* ne10 */ ne10,
+ /* ne11 */ 1,//ne11,
+ /* ne12 */ 1,//ne12,
+ /* ne13 */ 1,//ne13,
+ /* nb10 */ nb10,
+ /* nb11 */ nb11,
+ /* nb12 */ nb12,
+ /* ne0 */ ne0,
+ /* ne1 */ 1,//ne1,
+ /* nb1 */ nb1,
+ /* r2 */ 1,
+ /* r3 */ 1,
+ shared_values,
+ tgpig,
+ tiitg,
+ tiisg,
+ sgitg);
+}
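+
+// Each threadgroup of kernel_mul_mv_id handles one (expert slot, token) pair: idx = tgpig.z % nei0
+// is the expert slot, iid1 = tgpig.z / nei0 is the token, the expert index i02 is read from ids,
+// and the wrapped impl is then run on a single src1 row with the broadcast factors r2/r3 and the
+// batch dimensions forced to 1.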
+
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_bn_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_bn_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_bn_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
new file mode 100644
index 00000000..da4c9b9a
--- /dev/null
+++ b/ggml/src/ggml-quants.c
@@ -0,0 +1,14976 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#if GGML_USE_IQK_MULMAT
+#include "iqk/iqk_mul_mat.h"
+#endif
+
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h> // for GGML_ASSERT
+
+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid warnings for hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+ // Get absolute values of x vectors
+ const __m128i ax = _mm_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m128i sy = _mm_sign_epi8(y, x);
+ // Perform multiplication and create 16-bit values
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
+ const __m128i ones = _mm_set1_epi16(1);
+ return _mm_madd_epi16(ones, dot);
+}
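+
+// _mm_maddubs_epi16 needs an unsigned first operand, so the signed product x*y is rewritten as
+// abs(x) * sign_adjusted(y): _mm_sign_epi8(y, x) negates y where x < 0 and zeroes it where
+// x == 0, leaving every lane product unchanged. The final madd against ones then widens the
+// adjacent i16 pairs to i32, so each output lane is (scalar sketch, illustrative only):
+//
+//   out[k] = x[4k]*y[4k] + x[4k+1]*y[4k+1] + x[4k+2]*y[4k+2] + x[4k+3]*y[4k+3];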
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+ __m128 res = _mm256_extractf128_ps(x, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+ return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m256i shuf_mask = _mm256_set_epi64x(
+ 0x0303030303030303, 0x0202020202020202,
+ 0x0101010101010101, 0x0000000000000000);
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytes = _mm256_or_si256(bytes, bit_mask);
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
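+
+// Scalar view of the function above (sketch): out_byte[i] = ((x32 >> i) & 1) ? 0xFF : 0x00 for
+// i = 0..31. The shuffle copies source byte i/8 into output byte i, the OR with bit_mask sets
+// every bit except bit i%8, and the compare against all-ones therefore fires exactly when bit i
+// of x32 was set.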
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
+ return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+ const __m256i ones = _mm256_set1_epi16(1);
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+ return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+#else
+ const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+ __m256i high = _mm256_andnot_si256( lowByte, bytes );
+ __m256i low = _mm256_and_si256( lowByte, bytes );
+ high = _mm256_srli_epi16( high, 4 );
+ bytes = _mm256_or_si256( low, high );
+
+ // Compress uint16_t lanes into bytes
+ __m128i r0 = _mm256_castsi256_si128( bytes );
+ __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+ return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytesl = _mm_or_si128(bytesl, bit_mask);
+ bytesh = _mm_or_si128(bytesh, bit_mask);
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+ return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+ // Load 16 bytes from memory
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ tmpl = _mm_and_si128(lowMask, tmpl);
+ tmph = _mm_and_si128(lowMask, tmph);
+ return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+ const __m128i ones = _mm_set1_epi16(1);
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+ return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ const __m128i axl = _mm256_castsi256_si128(ax);
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
+ const __m128i syl = _mm256_castsi256_si128(sy);
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ const __m128i xl = _mm256_castsi256_si128(x);
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
+ const __m128i yl = _mm256_castsi256_si128(y);
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
+ // Get absolute values of x vectors
+ const __m128i axl = _mm_sign_epi8(xl, xl);
+ const __m128i axh = _mm_sign_epi8(xh, xh);
+ // Sign the values of the y vectors
+ const __m128i syl = _mm_sign_epi8(yl, xl);
+ const __m128i syh = _mm_sign_epi8(yh, xh);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
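+// Note on the sign trick above: _mm_maddubs_epi16 requires its first operand to
+// be unsigned, so |x| is used and the sign of x is transferred onto y with
+// _mm_sign_epi8; lane by lane ax*sy == x*y, so the pairwise sums are unchanged.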
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes1 = _mm_or_si128( low, high );
+ high = _mm_andnot_si128( lowByte, bytes2 );
+ low = _mm_and_si128( lowByte, bytes2 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes2 = _mm_or_si128( low, high );
+
+ return _mm_packus_epi16( bytes1, bytes2);
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+ __m128 res_0 =_mm_hadd_ps(a, b);
+ __m128 res_1 =_mm_hadd_ps(c, d);
+ __m128 res =_mm_hadd_ps(res_0, res_1);
+ res =_mm_hadd_ps(res, res);
+ res =_mm_hadd_ps(res, res);
+
+ return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8 bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
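+// e.g. byte j of table_b2b_0[b] is ((b >> j) & 1) << 4, so table_b2b_0[0x03] is
+// 0x0000000000001010, while table_b2b_1[b] is the complement (0x10 where bit j of b is 0)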
+#endif
+
+#if defined(__loongarch_asx)
+
+#ifdef __clang__
+#define VREGS_PREFIX "$vr"
+#define XREGS_PREFIX "$xr"
+#else // GCC
+#define VREGS_PREFIX "$f"
+#define XREGS_PREFIX "$f"
+#endif
+#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+// Convert __m128i to __m256i
+static inline __m256i ____m256i(__m128i in) {
+ __m256i out = __lasx_xvldi(0);
+ __asm__ volatile (
+ ".irp i," __ALL_REGS "\n\t"
+ " .ifc %[out], " XREGS_PREFIX"\\i \n\t"
+ " .irp j," __ALL_REGS "\n\t"
+ " .ifc %[in], " VREGS_PREFIX "\\j \n\t"
+ " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
+ " .endif \n\t"
+ " .endr \n\t"
+ " .endif \n\t"
+ ".endr \n\t"
+ : [out] "+f" (out) : [in] "f" (in)
+ );
+ return out;
+}
+// Convert two __m128i to __m256i
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
+ __m256i out;
+ __asm__ volatile (
+ ".irp i," __ALL_REGS "\n\t"
+ " .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
+ " .irp j," __ALL_REGS "\n\t"
+ " .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
+ " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
+ " .endif \n\t"
+ " .endr \n\t"
+ " .endif \n\t"
+ ".endr \n\t"
+ ".ifnc %[out], %[hi] \n\t"
+ ".irp i," __ALL_REGS "\n\t"
+ " .ifc %[out], " XREGS_PREFIX "\\i \n\t"
+ " .irp j," __ALL_REGS "\n\t"
+ " .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
+ " xvori.b $xr\\i, $xr\\j, 0 \n\t"
+ " .endif \n\t"
+ " .endr \n\t"
+ " .endif \n\t"
+ ".endr \n\t"
+ ".endif \n\t"
+ : [out] "=f" (out), [hi] "+f" (inhi)
+ : [lo] "f" (inlo)
+ );
+ return out;
+}
+// Convert __m256i low part to __m128i
+static inline __m128i lasx_extracti128_lo(__m256i in) {
+ __m128i out;
+ __asm__ volatile (
+ ".ifnc %[out], %[in] \n\t"
+ ".irp i," __ALL_REGS "\n\t"
+ " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
+ " .irp j," __ALL_REGS "\n\t"
+ " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
+ " vori.b $vr\\i, $vr\\j, 0 \n\t"
+ " .endif \n\t"
+ " .endr \n\t"
+ " .endif \n\t"
+ ".endr \n\t"
+ ".endif \n\t"
+ : [out] "=f" (out) : [in] "f" (in)
+ );
+ return out;
+}
+// Convert __m256i high part to __m128i
+static inline __m128i lasx_extracti128_hi(__m256i in) {
+ __m128i out;
+ __asm__ volatile (
+ ".irp i," __ALL_REGS "\n\t"
+ " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
+ " .irp j," __ALL_REGS "\n\t"
+ " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
+ " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
+ " .endif \n\t"
+ " .endr \n\t"
+ " .endif \n\t"
+ ".endr \n\t"
+ : [out] "=f" (out) : [in] "f" (in)
+ );
+ return out;
+}
+
+static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
+ v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
+ return (__m256i)__ret;
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+ v4i32 __ret = {d, c, b, a};
+ return (__m128i)__ret;
+}
+
+static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
+ v4i64 __ret = {d, c, b, a};
+ return (__m256i)__ret;
+}
+
+static __m256i lasx_insertf128( __m128i x, __m128i y) {
+ return lasx_set_q(x, y);
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+ __m128i mask_f, zero, tmp0, tmp2, mask;
+ int f = 0x8f;
+ mask_f = __lsx_vreplgr2vr_b(f);
+ zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // keep the low 4 index bits and the sign bit (pshufb semantics)
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // offset in-range indices by 16 so they select from 'a' in vshuf
+    mask = __lsx_vsle_b(zero, tmp0); // all-ones for lanes whose sign bit is clear
+    tmp2 = __lsx_vand_v(tmp0, mask); // zero out indices whose sign bit was set, so those lanes produce 0
+ return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
+ __m256i mask_f, zero, tmp0, tmp2, mask;
+ int f = 0x8f;
+ mask_f = __lasx_xvreplgr2vr_b(f);
+ zero = __lasx_xvldi(0);
+    tmp0 = __lasx_xvand_v(b, mask_f); // keep the low 4 index bits and the sign bit (pshufb semantics)
+    tmp0 = __lasx_xvori_b(tmp0, 0x10); // offset in-range indices by 16 so they select from 'a' in xvshuf
+    mask = __lasx_xvsle_b(zero, tmp0); // all-ones for lanes whose sign bit is clear
+    tmp2 = __lasx_xvand_v(tmp0, mask); // zero out indices whose sign bit was set, so those lanes produce 0
+ return __lasx_xvshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_extu8_16(__m128i a) {
+ __m128i zero = __lsx_vldi(0);
+ __m128i vlo = __lsx_vilvl_b(zero, a);
+ __m128i vhi = __lsx_vilvh_b(zero, a);
+ return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext8_16(__m128i a) {
+ __m128i sign = __lsx_vslti_b(a, 0);
+ __m128i vlo = __lsx_vilvl_b(sign, a);
+ __m128i vhi = __lsx_vilvh_b(sign, a);
+ return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext16_32(__m128i a) {
+    __m256i tmp1 = __lasx_xvldi(0); // initialized so the first xvinsgr2vr below does not read an indeterminate value
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
+ tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
+ return tmp1;
+}
+
+static __m128i lasx_extracti128( __m256i a, int pos) {
+ __m128i ret;
+ if( pos == 0)
+ {
+ ret = lasx_extracti128_lo(a);
+ } else {
+ ret = lasx_extracti128_hi(a);
+ }
+ return ret;
+}
+
+static __m128 lasx_extractf128( __m256 a, int pos) {
+ __m128 ret;
+ if( pos == 0)
+ {
+ ret = (__m128)lasx_extracti128_lo((__m256i)a);
+ } else {
+ ret = (__m128)lasx_extracti128_hi((__m256i)a);
+ }
+ return ret;
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+ __m128i tmp1 = __lsx_vpickev_h(b, a);
+ __m128i tmp2 = __lsx_vpickod_h(b, a);
+ return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+ __m128i tmp1 = __lsx_vpickev_w(b, a);
+ __m128i tmp2 = __lsx_vpickod_w(b, a);
+ return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+ __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+ __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+ return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
+ __m256i tmp1, tmp2;
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
+ return __lasx_xvsadd_h(tmp1, tmp2);
+}
+
+static __m256i lasx_madd_h(__m256i a, __m256i b) {
+ __m256i tmp1, tmp2;
+ tmp1 = __lasx_xvmulwev_w_h(a, b);
+ tmp2 = __lasx_xvmulwod_w_h(a, b);
+ return __lasx_xvadd_w(tmp1, tmp2);
+}
+
+static __m256i lasx_packs_w(__m256i a, __m256i b) {
+ __m256i tmp, tmp1;
+ tmp = __lasx_xvsat_w(a, 15);
+ tmp1 = __lasx_xvsat_w(b, 15);
+ return __lasx_xvpickev_h(tmp1, tmp);
+}
+
+static __m256i lasx_packs_h(__m256i a, __m256i b) {
+ __m256i tmp, tmp1;
+ tmp = __lasx_xvsat_h(a, 7);
+ tmp1 = __lasx_xvsat_h(b, 7);
+ return __lasx_xvpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_w(a, 15);
+ tmp1 = __lsx_vsat_w(b, 15);
+ return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_h(a, 7);
+ tmp1 = __lsx_vsat_h(b, 7);
+ return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+ __m128i tmp, tmp1;
+ tmp = __lsx_vsat_hu(a, 7);
+ tmp1 = __lsx_vsat_hu(b, 7);
+ return __lsx_vpickev_b(tmp1, tmp);
+}
+
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+ __m128i tmp1, tmp2;
+ tmp1 = __lsx_vmulwev_h_b(a, b);
+ tmp2 = __lsx_vmulwod_h_b(a, b);
+ return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+ __m128i tmp1, tmp2;
+ tmp1 = __lsx_vmulwev_w_h(a, b);
+ tmp2 = __lsx_vmulwod_w_h(a, b);
+ return __lsx_vadd_w(tmp1, tmp2);
+}
+
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+ // Get absolute values of x vectors
+ const __m128i ax = __lsx_vsigncov_b(x, x);
+ // Sign the values of the y vectors
+ const __m128i sy = __lsx_vsigncov_b(x, y);
+ // Perform multiplication and create 16-bit values
+ const __m128i dot = lsx_maddubs_h(ax, sy);
+ const __m128i ones = __lsx_vreplgr2vr_h(1);
+ return lsx_madd_h(ones, dot);
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+ __m128 res = lasx_extractf128(x, 1);
+ ft_union tmp;
+ res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
+ res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
+ res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
+ tmp.i = __lsx_vpickve2gr_w(res, 0);
+ return tmp.f;
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+
+ __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
+ __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
+
+ __m128i tmp1_128 = lasx_extracti128_lo(tmp1);
+ __m128i tmp2_128 = lasx_extracti128_lo(tmp2);
+
+ __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
+
+ __m128i ev = __lsx_vpickev_w(sum128, sum128);
+ __m128i od = __lsx_vpickod_w(sum128, sum128);
+ __m128i sum64 = __lsx_vadd_w(ev, od);
+
+ int sum64_1, sum64_2;
+ sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+ sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+ return sum64_1 + sum64_2;
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+ __m128i ev = __lsx_vpickev_w(a, a);
+ __m128i od = __lsx_vpickod_w(a, a);
+ __m128i sum64 = __lsx_vadd_w(ev, od);
+
+ int sum64_1, sum64_2;
+ sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+ sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+ return sum64_1 + sum64_2;
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m256i shuf_mask = lasx_set_d(
+ 0x0303030303030303, 0x0202020202020202,
+ 0x0101010101010101, 0x0000000000000000);
+
+ __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
+ const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
+ bytes = __lasx_xvor_v(bytes, bit_mask);
+ return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in the [0 .. 15] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+ const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
+ __m128i hi = __lsx_vsrli_h(lo, 4);
+ return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+ __m256i v = __lasx_xvpackod_h(x, x);
+ __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
+ return __lasx_xvffint_s_w(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = lasx_maddubs_h(ax, sy);
+ return sum_i16_pairs_float(dot);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+
+ // Get absolute values of x vectors
+ const __m256i ax = __lasx_xvsigncov_b(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = __lasx_xvsigncov_b(x, y);
+
+ return mul_sum_us8_pairs_float(ax, sy);
+}
+
+static inline __m128i packNibbles( __m256i bytes ) {
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
+ __m256i high = __lasx_xvandn_v(lowByte, bytes);
+ __m256i low = __lasx_xvand_v(lowByte, bytes);
+ high = __lasx_xvsrli_h(high, 4);
+ bytes = __lasx_xvor_v(low, high);
+ // Compress uint16_t lanes into bytes
+ __m128i *r0 = (__m128i *)&bytes;
+ __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
+ __m128i *r1 = (__m128i *)&tmp_h128;
+
+ __m128i zero = __lsx_vldi(0);
+ __m128i tmp, tmp2, tmp3;
+
+ tmp = __lsx_vmax_h(zero, *r0);
+ tmp2 = __lsx_vsat_hu(tmp, 7);
+
+ tmp = __lsx_vmax_h(zero, *r1);
+ tmp3 = __lsx_vsat_hu(tmp, 7);
+ return __lsx_vpickev_b(tmp3, tmp2);
+}
+#endif //__loongarch_asx
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
+ }
+ }
+}
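+// Layout recap for q4_0: each block stores 32 weights as x[j] ~= d * (q[j] - 8)
+// with q[j] in [0, 15] and d = max / -8, where max is the value of largest
+// magnitude in the block; element j of the first half and element j of the second
+// half share qs[j] as its low and high nibble. For example, a block whose extreme
+// value is -2.0 gets d = 0.25, and -2.0 quantizes to q = 0.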
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q4_0_ref(x, y, k);
+}
+
+
+void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+ const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q4_1_ref(x, y, k);
+}
+
+void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // take the 5th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(qh));
+ }
+}
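+// Layout recap for q5_0: the low 4 bits of each 5-bit quant are packed into the
+// nibbles of qs exactly as in q4_0, while the 32 fifth bits are collected in the
+// 32-bit qh field, bit j for element j of the first half and bit j + 16 for
+// element j of the second half; dequantization reconstructs x[j] ~= d * (q[j] - 16)
+// with q[j] in [0, 31].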
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q5_0_ref(x, y, k);
+}
+
+void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+ const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 5) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // take the 5th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+ }
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q5_1_ref(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = x[i*QK8_0 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = x[i*QK8_0 + j]*id;
+
+ y[i].qs[j] = roundf(x0);
+ }
+ }
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0 * restrict y = vy;
+
+#if GGML_USE_IQK_MULMAT
+ const int nb4 = 4*(nb/4);
+#else
+ const int nb4 = -1;
+#endif
+#if defined(__ARM_NEON)
+ block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+ }
+ }
+#elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+ }
+ }
+#elif defined(__AVX2__) || defined(__AVX__)
+ block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
+#ifdef __AVX2__
+ const bool pack = true;
+#else
+ const bool pack = false;
+#endif
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ if (pack && i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+            // We now have our signed bytes, but their order is wrong:
+            // these AVX2 pack instructions process the two 128-bit lanes independently.
+            // The permute below restores the correct order.
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
+#else
+        // AVX lacks the required 256-bit integer pack instructions,
+        // so we split the registers in half and use their SSE counterparts
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+ }
+#elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ for (int i = 0; i < nb; i++) {
+ vector float srcv [8];
+ vector float asrcv[8];
+ vector float amaxv[8];
+ vector signed int vi[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+ vec_extract(amaxv[0], 1)),
+ MAX(vec_extract(amaxv[0], 2),
+ vec_extract(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+ const vector float vid = vec_splats(id);
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const vector float v = vec_round(vec_mul(srcv[j], vid));
+ vi[j] = vec_cts(v, 0);
+ }
+ vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]);
+ vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+ }
+
+#elif defined(__loongarch_asx)
+ for (int i = 0; i < nb; i++) {
+ ft_union fi;
+ __m256 v0 = (__m256)__lasx_xvld( x , 0);
+ __m256 v1 = (__m256)__lasx_xvld( x , 32);
+ __m256 v2 = (__m256)__lasx_xvld( x , 64);
+ __m256 v3 = (__m256)__lasx_xvld( x , 96);
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+ __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+ __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+ __m128 tmp = max4;
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
+ fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+ const float max_scalar = fi.f;
+
+ // Quantize these floats
+ const float d = max_scalar / 127.f;
+ y[i].d = GGML_FP32_TO_FP16(d);
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+ const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
+
+ // Apply the multiplier
+ v0 = __lasx_xvfmul_s( v0, mul );
+ v1 = __lasx_xvfmul_s( v1, mul );
+ v2 = __lasx_xvfmul_s( v2, mul );
+ v3 = __lasx_xvfmul_s( v3, mul );
+
+ // Round to nearest integer
+ __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+ __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+ __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+ __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+ __m128i ni0 = lasx_extracti128( i0, 0 );
+ __m128i ni1 = lasx_extracti128( i0, 1);
+ __m128i ni2 = lasx_extracti128( i1, 0);
+ __m128i ni3 = lasx_extracti128( i1, 1);
+ __m128i ni4 = lasx_extracti128( i2, 0);
+ __m128i ni5 = lasx_extracti128( i2, 1);
+ __m128i ni6 = lasx_extracti128( i3, 0);
+ __m128i ni7 = lasx_extracti128( i3, 1);
+
+ // Convert int32 to int16
+ ni0 = lsx_packs_w( ni0, ni1 );
+ ni2 = lsx_packs_w( ni2, ni3 );
+ ni4 = lsx_packs_w( ni4, ni5 );
+ ni6 = lsx_packs_w( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = lsx_packs_h( ni0, ni2 );
+ ni4 = lsx_packs_h( ni4, ni6 );
+
+ __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
+ __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+
+ }
+#else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_0_ref(x, y, k);
+#endif
+}
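+// Note on the layout above: with GGML_USE_IQK_MULMAT enabled, the NEON and AVX2
+// paths write the first 4*(nb/4) blocks as block_q8_0_x4, i.e. groups of four
+// consecutive blocks storing their four fp16 scales first and then the 4*32
+// quants contiguously (y4[i4].d[ir], y4[i4].qs[32*ir + ...]); the iqk matrix
+// multiplication kernels presumably expect this interleaving. Any trailing
+// blocks (when nb is not a multiple of 4) keep the plain block_q8_0 layout.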
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+ assert(QK8_1 == 32);
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_1; j++) {
+ const float v = x[i*QK8_1 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ int sum = 0;
+
+ for (int j = 0; j < QK8_1/2; ++j) {
+ const float v0 = x[i*QK8_1 + j]*id;
+ const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+ y[i].qs[ j] = roundf(v0);
+ y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+ sum += y[i].qs[ j];
+ sum += y[i].qs[QK8_1/2 + j];
+ }
+
+ y[i].s = GGML_FP32_TO_FP16(sum*d);
+ }
+}
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ block_q8_1 * restrict y = vy;
+
+#if GGML_USE_IQK_MULMAT
+ const int nb4 = 4*(nb/4);
+#else
+ const int nb4 = -1;
+#endif
+#if defined(__ARM_NEON)
+ block_q8_1_x4 * restrict y4 = vy;
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ if (i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+
+ int32x4_t accv = vdupq_n_s32(0);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ if (i < nb4) {
+ y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+ y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+ y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+ y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+ } else {
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+
+ accv = vaddq_s32(accv, vi);
+ }
+
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+ }
+ }
+#elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ v128_t accv = wasm_i32x4_splat(0);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+ accv = wasm_i32x4_add(accv, vi);
+ }
+
+ y[i].s = GGML_FP32_TO_FP16(
+ d * (wasm_i32x4_extract_lane(accv, 0) +
+ wasm_i32x4_extract_lane(accv, 1) +
+ wasm_i32x4_extract_lane(accv, 2) +
+ wasm_i32x4_extract_lane(accv, 3)));
+ }
+#elif defined(__AVX2__) || defined(__AVX__)
+ block_q8_1_x4 * restrict y4 = vy;
+#ifdef __AVX2__
+ const bool pack = true;
+#else
+ const bool pack = false;
+#endif
+ for (int i = 0; i < nb; i++) {
+ int i4 = i/4, ir = i%4;
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float max_scalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = max_scalar / 127.f;
+ if (pack && i < nb4) {
+ y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(d);
+ }
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+ // Compute the sum of the quants and set y[i].s
+ if (i < nb4) {
+ y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ } else {
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+ }
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We now have our signed bytes, but their order is wrong:
+        // these AVX2 pack instructions process the two 128-bit lanes independently.
+        // The permute below restores the correct order.
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ if (i < nb4) {
+ _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0);
+ } else {
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ }
+#else
+        // AVX lacks the required 256-bit integer pack instructions,
+        // so we split the registers in half and use their SSE counterparts
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Compute the sum of the quants and set y[i].s
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+ }
+#elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+ // compute sum for y[i].s
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+ // set y[i].s
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+ y[i].s = GGML_FP32_TO_FP16(sum*d);
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ for (int i = 0; i < nb; i++) {
+ vector float srcv [8];
+ vector float asrcv[8];
+ vector float amaxv[8];
+ vector signed int vi[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+ vec_extract(amaxv[0], 1)),
+ MAX(vec_extract(amaxv[0], 2),
+ vec_extract(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+ const vector float vid = vec_splats(id);
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ vector int accv = vec_splats(0);
+
+ for (int j = 0; j < 8; j++) {
+ const vector float v = vec_round(vec_mul(srcv[j], vid));
+ vi[j] = vec_cts(v, 0);
+
+ accv = vec_add(accv, vi[j]);
+ }
+ vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]);
+ vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+
+ accv = vec_add(accv, vec_sld(accv, accv, 4));
+ accv = vec_add(accv, vec_sld(accv, accv, 8));
+ y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0));
+ }
+
+#elif defined(__loongarch_asx)
+ for (int i = 0; i < nb; i++) {
+ ft_union ft;
+ __m256 v0 = (__m256)__lasx_xvld( x , 0 );
+ __m256 v1 = (__m256)__lasx_xvld( x , 32 );
+ __m256 v2 = (__m256)__lasx_xvld( x , 64 );
+ __m256 v3 = (__m256)__lasx_xvld( x , 96 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+ __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+ __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+ __m128 tmp = max4;
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
+ ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+ const float max_scalar = ft.f;
+
+ // Quantize these floats
+ const float d = max_scalar / 127.f;
+ y[i].d = GGML_FP32_TO_FP16(d);
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+ const __m256 mul = __lasx_xvreplfr2vr_s( id );
+
+ // Apply the multiplier
+ v0 = __lasx_xvfmul_s( v0, mul );
+ v1 = __lasx_xvfmul_s( v1, mul );
+ v2 = __lasx_xvfmul_s( v2, mul );
+ v3 = __lasx_xvfmul_s( v3, mul );
+
+ // Round to nearest integer
+ __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+ __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+ __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+ __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+ __m128i ni0 = lasx_extracti128(i0, 0);
+ __m128i ni1 = lasx_extracti128( i0, 1);
+ __m128i ni2 = lasx_extracti128( i1, 0);
+ __m128i ni3 = lasx_extracti128( i1, 1);
+ __m128i ni4 = lasx_extracti128( i2, 0 );
+ __m128i ni5 = lasx_extracti128( i2, 1);
+ __m128i ni6 = lasx_extracti128( i3, 0);
+ __m128i ni7 = lasx_extracti128( i3, 1);
+
+ // Compute the sum of the quants and set y[i].s
+ const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
+ const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
+ y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
+
+ // Convert int32 to int16
+ ni0 = lsx_packs_w( ni0, ni1 );
+ ni2 = lsx_packs_w( ni2, ni3 );
+ ni4 = lsx_packs_w( ni4, ni5 );
+ ni6 = lsx_packs_w( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = lsx_packs_h( ni0, ni2 );
+ ni4 = lsx_packs_h( ni4, ni6 );
+
+ __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
+ __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+ }
+#else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_1_ref(x, y, k);
+#endif
+}
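+// Besides the quants, q8_1 stores s = d * sum(q) for each block (d[ir+4] in the
+// x4 layout); the dot-product kernels for the offset formats (q4_1, q5_1)
+// presumably use this precomputed sum to fold each block's minimum into the
+// result without revisiting the individual quants.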
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F) - 8;
+ const int x1 = (x[i].qs[j] >> 4) - 8;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F);
+ const int x1 = (x[i].qs[j] >> 4);
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
+ static const int qk = QK8_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk; ++j) {
+ y[i*qk + j] = x[i].qs[j]*d;
+ }
+ }
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
+static inline int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+}
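+// nearest_int relies on the float representation: adding 12582912.0f = 1.5 * 2^23
+// puts the result in [2^23, 2^24), where the spacing between floats is exactly 1,
+// so the addition itself rounds fval to the nearest integer and the low 23
+// mantissa bits then hold round(fval) + 2^22; masking with 0x007fffff and
+// subtracting 0x00400000 recovers the rounded value. Hence the assert bounding
+// fval by 4194303 = 2^22 - 1.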
+
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+ const float * restrict qw) {
+ float max = 0;
+ float amax = 0;
+ for (int i = 0; i < n; ++i) {
+ float ax = fabsf(x[i]);
+ if (ax > amax) { amax = ax; max = x[i]; }
+ }
+ if (amax < GROUP_MAX_EPS) { // all zero
+ for (int i = 0; i < n; ++i) {
+ L[i] = 0;
+ }
+ return 0.f;
+ }
+ float iscale = -nmax / max;
+ if (rmse_type == 0) {
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+ }
+ return 1/iscale;
+ }
+ bool return_early = false;
+ if (rmse_type < 0) {
+ rmse_type = -rmse_type;
+ return_early = true;
+ }
+ float sumlx = 0;
+ float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+ for (volatile int i = 0; i < n; ++i) {
+#else
+ for (int i = 0; i < n; ++i) {
+#endif
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l + nmax;
+ float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ float scale = suml2 ? sumlx/suml2 : 0.0f;
+ if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+ float best = scale * sumlx;
+ for (int is = -9; is <= 9; ++is) {
+ if (is == 0) {
+ continue;
+ }
+ iscale = -(nmax + 0.1f*is) / max;
+ sumlx = suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+ }
+ scale = sumlx/suml2; best = scale*sumlx;
+ }
+ }
+ return scale;
+}
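+// make_qx_quants: for a fixed assignment l[i], the scale minimizing the weighted
+// error sum_i w_i*(x_i - scale*l_i)^2 is scale = sum(w*x*l)/sum(w*l^2), i.e.
+// sumlx/suml2 above; the loop over 'is' perturbs the initial iscale = -nmax/max
+// and keeps whichever assignment maximizes sumlx^2/suml2, which is equivalent to
+// minimizing that weighted error.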
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+ float max = 0;
+ float amax = 0;
+ for (int i = 0; i < n; ++i) {
+ float ax = fabsf(x[i]);
+ if (ax > amax) { amax = ax; max = x[i]; }
+ }
+ if (amax < GROUP_MAX_EPS) { // all zero
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
+ return 0.f;
+ }
+ float iscale = -nmax / max;
+ if (do_rmse) {
+ float sumlx = 0;
+ float suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l;
+ float w = x[i]*x[i];
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ for (int itry = 0; itry < 5; ++itry) {
+ int n_changed = 0;
+ for (int i = 0; i < n; ++i) {
+ float w = x[i]*x[i];
+ float slx = sumlx - w*x[i]*L[i];
+ if (slx > 0) {
+ float sl2 = suml2 - w*L[i]*L[i];
+ int new_l = nearest_int(x[i] * sl2 / slx);
+ new_l = MAX(-nmax, MIN(nmax-1, new_l));
+ if (new_l != L[i]) {
+ slx += w*x[i]*new_l;
+ sl2 += w*new_l*new_l;
+ if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
+ ++n_changed;
+ }
+ }
+ }
+ }
+ if (!n_changed) {
+ break;
+ }
+ }
+ for (int i = 0; i < n; ++i) {
+ L[i] += nmax;
+ }
+ return sumlx / suml2;
+ }
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l + nmax;
+ }
+ return 1/iscale;
+}
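+// The do_rmse branch of make_q3_quants refines the assignment by coordinate
+// descent: each L[i] is re-chosen with the others held fixed (new_l is x[i]
+// rounded against the scale computed without element i), and the change is kept
+// only if it increases sumlx^2/suml2, i.e. lowers the weighted RMSE.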
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+ int ntry, float alpha) {
+ float min = x[0];
+ float max = x[0];
+ for (int i = 1; i < n; ++i) {
+ if (x[i] < min) min = x[i];
+ if (x[i] > max) max = x[i];
+ }
+ if (max == min) {
+ for (int i = 0; i < n; ++i) L[i] = 0;
+ *the_min = 0;
+ return 0.f;
+ }
+ if (min > 0) min = 0;
+ float iscale = nmax/(max - min);
+ float scale = 1/iscale;
+ for (int itry = 0; itry < ntry; ++itry) {
+ float sumlx = 0; int suml2 = 0;
+ bool did_change = false;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ l = MAX(0, MIN(nmax, l));
+ if (l != L[i]) {
+ L[i] = l;
+ did_change = true;
+ }
+ sumlx += (x[i] - min)*l;
+ suml2 += l*l;
+ }
+ scale = sumlx/suml2;
+ float sum = 0;
+ for (int i = 0; i < n; ++i) {
+ sum += x[i] - scale*L[i];
+ }
+ min = alpha*min + (1 - alpha)*sum/n;
+ if (min > 0) min = 0;
+ iscale = 1/scale;
+ if (!did_change) break;
+ }
+ *the_min = -min;
+ return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+ float rmin, float rdelta, int nstep, bool use_mad) {
+ float min = x[0];
+ float max = x[0];
+ float sum_w = weights[0];
+ float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+ for (volatile int i = 1; i < n; ++i) {
+#else
+ for (int i = 1; i < n; ++i) {
+#endif
+ if (x[i] < min) min = x[i];
+ if (x[i] > max) max = x[i];
+ float w = weights[i];
+ sum_w += w;
+ sum_x += w * x[i];
+ }
+ if (min > 0) min = 0;
+ if (max == min) {
+ for (int i = 0; i < n; ++i) L[i] = 0;
+ *the_min = -min;
+ return 0.f;
+ }
+ float iscale = nmax/(max - min);
+ float scale = 1/iscale;
+ float best_mad = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ L[i] = MAX(0, MIN(nmax, l));
+ float diff = scale * L[i] + min - x[i];
+ diff = use_mad ? fabsf(diff) : diff * diff;
+ float w = weights[i];
+ best_mad += w * diff;
+ }
+ if (nstep < 1) {
+ *the_min = -min;
+ return scale;
+ }
+ for (int is = 0; is <= nstep; ++is) {
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ l = MAX(0, MIN(nmax, l));
+ Laux[i] = l;
+ float w = weights[i];
+ sum_l += w*l;
+ sum_l2 += w*l*l;
+ sum_xl += w*l*x[i];
+ }
+ float D = sum_w * sum_l2 - sum_l * sum_l;
+ if (D > 0) {
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+ if (this_min > 0) {
+ this_min = 0;
+ this_scale = sum_xl / sum_l2;
+ }
+ float mad = 0;
+ for (int i = 0; i < n; ++i) {
+ float diff = this_scale * Laux[i] + this_min - x[i];
+ diff = use_mad ? fabsf(diff) : diff * diff;
+ float w = weights[i];
+ mad += w * diff;
+ }
+ if (mad < best_mad) {
+ for (int i = 0; i < n; ++i) {
+ L[i] = Laux[i];
+ }
+ best_mad = mad;
+ scale = this_scale;
+ min = this_min;
+ }
+ }
+ }
+ *the_min = -min;
+ return scale;
+}
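+// The closed-form (this_scale, this_min) above solves the weighted normal
+// equations for fitting x_i ~= scale*l_i + min with weights w_i:
+//   scale*sum(w*l*l) + min*sum(w*l) = sum(w*l*x)
+//   scale*sum(w*l)   + min*sum(w)   = sum(w*x)
+// via Cramer's rule with determinant D = sum(w)*sum(w*l*l) - sum(w*l)^2; a fitted
+// min that comes out positive is clamped to 0 (with the scale refitted as
+// sum(w*l*x)/sum(w*l*l)) so that the returned minimum, -min, stays non-negative.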
+
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+ if (j < 4) {
+ *d = q[j] & 63; *m = q[j + 4] & 63;
+ } else {
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+}
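+// get_scale_min_k4 unpacks the 12-byte packed scales/mins used by the q4_K and
+// q5_K blocks: the first four 6-bit scales live in the low 6 bits of q[0..3] and
+// the first four 6-bit mins in the low 6 bits of q[4..7]; scales/mins 4..7 keep
+// their low 4 bits in the low/high nibbles of q[8..11] and their top 2 bits in
+// the high 2 bits of q[0..3] (scales) and q[4..7] (mins).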
+
+//========================- 2-bit (de)-quantization
+
+void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[16];
+ float weights[16];
+ float mins[QK_K/16];
+ float scales[QK_K/16];
+
+ const float q4scale = 15.f;
+
+ for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // since the min is subtracted, scales are always positive
+ float max_min = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+ scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+ float scale = scales[j];
+ if (scale > max_scale) {
+ max_scale = scale;
+ }
+ float min = mins[j];
+ if (min > max_min) {
+ max_min = min;
+ }
+ }
+
+ if (max_scale > 0) {
+ float iscale = q4scale/max_scale;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(iscale*scales[j]);
+ y[i].scales[j] = l;
+ }
+ y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
+ } else {
+ for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+ y[i].d = GGML_FP32_TO_FP16(0.f);
+ }
+ if (max_min > 0) {
+ float iscale = q4scale/max_min;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(iscale*mins[j]);
+ y[i].scales[j] |= (l << 4);
+ }
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
+ } else {
+ y[i].dmin = GGML_FP32_TO_FP16(0.f);
+ }
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int((x[16*j + ii] + dm)/d);
+ l = MAX(0, MIN(3, l));
+ L[16*j + ii] = l;
+ }
+ }
+
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+ }
+ }
+
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * q = x[i].qs;
+
+ int is = 0;
+ float dl, ml;
+ for (int n = 0; n < QK_K; n += 128) {
+ int shift = 0;
+ for (int j = 0; j < 4; ++j) {
+
+ uint8_t sc = x[i].scales[is++];
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+ sc = x[i].scales[is++];
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+ shift += 2;
+ }
+ q += 32;
+ }
+ }
+}
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
+ quantize_row_q2_K_ref(x, vy, k);
+}
+
+static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+ float rmin, float rdelta, int nstep, bool use_mad) {
+ float min = x[0];
+ float max = x[0];
+ float sum_w = weights ? weights[0] : x[0]*x[0];
+ float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+ // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+ for (volatile int i = 1; i < n; ++i) {
+#else
+ for (int i = 1; i < n; ++i) {
+#endif
+ if (x[i] < min) min = x[i];
+ if (x[i] > max) max = x[i];
+ float w = weights ? weights[i] : x[i]*x[i];
+ sum_w += w;
+ sum_x += w * x[i];
+ }
+ if (min > 0) {
+ min = 0;
+ }
+ if (max <= min) {
+ memset(L, 0, n);
+ *the_min = -min;
+ return 0.f;
+ }
+ float iscale = nmax/(max - min);
+ float scale = 1/iscale;
+ float best_mad = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ L[i] = MAX(0, MIN(nmax, l));
+ float diff = scale * L[i] + min - x[i];
+ diff = use_mad ? fabsf(diff) : diff*diff;
+ float w = weights ? weights[i] : x[i]*x[i];
+ best_mad += w * diff;
+ }
+ if (nstep < 1) {
+ *the_min = -min;
+ return scale;
+ }
+ for (int is = 0; is <= nstep; ++is) {
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ l = MAX(0, MIN(nmax, l));
+ Laux[i] = l;
+ float w = weights ? weights[i] : x[i]*x[i];
+ sum_l += w*l;
+ sum_l2 += w*l*l;
+ sum_xl += w*l*x[i];
+ }
+ float D = sum_w * sum_l2 - sum_l * sum_l;
+ if (D > 0) {
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+ if (this_min > 0) {
+ this_min = 0;
+ this_scale = sum_xl / sum_l2;
+ }
+ float mad = 0;
+ for (int i = 0; i < n; ++i) {
+ float diff = this_scale * Laux[i] + this_min - x[i];
+ diff = use_mad ? fabsf(diff) : diff*diff;
+ float w = weights ? weights[i] : x[i]*x[i];
+ mad += w * diff;
+ }
+ if (mad < best_mad) {
+ for (int i = 0; i < n; ++i) {
+ L[i] = Laux[i];
+ }
+ best_mad = mad;
+ scale = this_scale;
+ min = this_min;
+ }
+ }
+ }
+ *the_min = -min;
+ return scale;
+}
+
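+// Quantizes n non-negative values (block scales or mins) to integers in [0, nmax]:
+// it scans a few candidate inverse scales around nmax/max, keeps the one with the
+// smallest weighted MSE, then greedily re-rounds individual entries while the
+// weighted objective sumlx^2/suml2 keeps improving. Returns the least-squares
+// scale sumlx/suml2.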
+static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+ float max = 0;
+ for (int i = 0; i < n; ++i) {
+ max = MAX(max, x[i]);
+ }
+ if (!max) { // all zero
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
+ return 0.f;
+ }
+ float iscale = nmax / max;
+ for (int i = 0; i < n; ++i) {
+ L[i] = nearest_int(iscale * x[i]);
+ }
+ float scale = 1/iscale;
+ float best_mse = 0;
+ for (int i = 0; i < n; ++i) {
+ float diff = x[i] - scale*L[i];
+ float w = quant_weights[i];
+ best_mse += w*diff*diff;
+ }
+ for (int is = -4; is <= 4; ++is) {
+ if (is == 0) continue;
+ float iscale_is = (0.1f*is + nmax)/max;
+ float scale_is = 1/iscale_is;
+ float mse = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale_is*x[i]);
+ l = MIN(nmax, l);
+ float diff = x[i] - scale_is*l;
+ float w = quant_weights[i];
+ mse += w*diff*diff;
+ }
+ if (mse < best_mse) {
+ best_mse = mse;
+ iscale = iscale_is;
+ }
+ }
+ float sumlx = 0;
+ float suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MIN(nmax, l);
+ L[i] = l;
+ float w = quant_weights[i];
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ for (int itry = 0; itry < 5; ++itry) {
+ int n_changed = 0;
+ for (int i = 0; i < n; ++i) {
+ float w = quant_weights[i];
+ float slx = sumlx - w*x[i]*L[i];
+ float sl2 = suml2 - w*L[i]*L[i];
+ if (slx > 0 && sl2 > 0) {
+ int new_l = nearest_int(x[i] * sl2 / slx);
+ new_l = MIN(nmax, new_l);
+ if (new_l != L[i]) {
+ slx += w*x[i]*new_l;
+ sl2 += w*new_l*new_l;
+ if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
+ ++n_changed;
+ }
+ }
+ }
+ }
+ if (!n_changed) {
+ break;
+ }
+ }
+ return sumlx/suml2;
+}
+
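+// Importance-weighted Q2_K quantization, used when quant_weights is provided
+// (typically derived from an importance matrix). Per 16-value sub-block the
+// weights are qw[l]*sqrt(sigma2 + x[l]^2); make_qkx3_quants fits a scale/min for
+// each sub-block, and make_qp_quants then quantizes those scales and mins to 4 bits.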
+static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+ GGML_ASSERT(quant_weights);
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+ const bool requantize = true;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[16];
+ float mins[QK_K/16];
+ float scales[QK_K/16];
+ float sw[QK_K/16];
+ float weight[16];
+ uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+ for (int i = 0; i < nb; i++) {
+ memset(sw, 0, QK_K/16*sizeof(float));
+ float sumx2 = 0;
+ for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+ float sigma2 = sumx2/QK_K;
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float * restrict qw = quant_weights + QK_K * i + 16*j;
+ for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+ for (int l = 0; l < 16; ++l) sw[j] += weight[l];
+ scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+ }
+
+ float dm, mm;
+ dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+ mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+
+ y[i].d = GGML_FP32_TO_FP16(dm);
+ y[i].dmin = GGML_FP32_TO_FP16(mm);
+ dm = GGML_FP16_TO_FP32(y[i].d);
+ mm = GGML_FP16_TO_FP32(y[i].dmin);
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+ }
+
+ if (requantize) {
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float d = dm * (y[i].scales[j] & 0xF);
+ if (!d) continue;
+ const float m = mm * (y[i].scales[j] >> 4);
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int((x[16*j + ii] + m)/d);
+ l = MAX(0, MIN(3, l));
+ L[16*j + ii] = l;
+ }
+ }
+ }
+
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+ }
+ }
+
+ x += QK_K;
+ }
+}
+
+size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+ if (!quant_weights) {
+ quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
+ }
+ else {
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ }
+ return nrow * row_size;
+}
+
+//========================= 3-bit (de)-quantization
+
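+// Layout summary for block_q3_K (QK_K = 256), as used below:
+//   - scales[12]: sixteen 6-bit sub-block scales; bytes 0-7 hold the low 4 bits
+//                 (low nibble for sub-blocks 0-7, high nibble for 8-15) and
+//                 bytes 8-11 hold the top 2 bits, four per byte
+//   - hmask[32] : high bit of each 3-bit quant, stored as bit-planes (bit b of
+//                 hmask[m] belongs to quant 32*b + m)
+//   - qs[64]    : low 2 bits of the quants, four per byte
+// Dequantization reconstructs x ~= d * (sc - 32) * q with q in [-4, 3].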
+void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ int8_t L[QK_K];
+ float scales[QK_K / 16];
+
+ for (int i = 0; i < nb; i++) {
+
+ float max_scale = 0;
+ float amax = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
+ float scale = fabsf(scales[j]);
+ if (scale > amax) {
+ amax = scale; max_scale = scales[j];
+ }
+ }
+
+ memset(y[i].scales, 0, 12);
+ if (max_scale) {
+ float iscale = -32.f/max_scale;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int8_t l = nearest_int(iscale*scales[j]);
+ l = MAX(-32, MIN(31, l)) + 32;
+ if (j < 8) {
+ y[i].scales[j] = l & 0xF;
+ } else {
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
+ }
+ l >>= 4;
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+ }
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
+ } else {
+ y[i].d = GGML_FP32_TO_FP16(0.f);
+ }
+
+ int8_t sc;
+ for (int j = 0; j < QK_K/16; ++j) {
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+ float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) {
+ continue;
+ }
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-4, MIN(3, l));
+ L[16*j + ii] = l + 4;
+ }
+ }
+
+ memset(y[i].hmask, 0, QK_K/8);
+ // We put the high bit of the first QK_K/8 quants into bit 0 of hmask, the high bit of the next QK_K/8 into bit 1, etc.
+ int m = 0;
+ uint8_t hm = 1;
+ for (int j = 0; j < QK_K; ++j) {
+ if (L[j] > 3) {
+ y[i].hmask[m] |= hm;
+ L[j] -= 4;
+ }
+ if (++m == QK_K/8) {
+ m = 0; hm <<= 1;
+ }
+ }
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+ }
+ }
+
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ const uint32_t kmask1 = 0x03030303;
+ const uint32_t kmask2 = 0x0f0f0f0f;
+
+ uint32_t aux[4];
+ const int8_t * scales = (const int8_t*)aux;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q = x[i].qs;
+ const uint8_t * restrict hm = x[i].hmask;
+ uint8_t m = 1;
+
+ memcpy(aux, x[i].scales, 12);
+ uint32_t tmp = aux[2];
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+ int is = 0;
+ float dl;
+ for (int n = 0; n < QK_K; n += 128) {
+ int shift = 0;
+ for (int j = 0; j < 4; ++j) {
+
+ dl = d_all * (scales[is++] - 32);
+ for (int l = 0; l < 16; ++l) {
+ *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+ }
+
+ dl = d_all * (scales[is++] - 32);
+ for (int l = 0; l < 16; ++l) {
+ *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+ }
+
+ shift += 2;
+ m <<= 1;
+ }
+ q += 32;
+ }
+
+ }
+}
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
+ quantize_row_q3_K_ref(x, vy, k);
+}
+
+static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
+ assert(n_per_row % QK_K == 0);
+ const int nb = n_per_row / QK_K;
+
+ int8_t L[QK_K];
+ float scales[QK_K / 16];
+ float weight[16];
+ float sw[QK_K / 16];
+ int8_t Ls[QK_K / 16];
+
+ for (int i = 0; i < nb; i++) {
+
+ float sumx2 = 0;
+ for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K * i + 16*j;
+ for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
+ } else {
+ for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
+ }
+ float sumw = 0;
+ for (int l = 0; l < 16; ++l) sumw += weight[l];
+ sw[j] = sumw;
+
+ scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
+
+ }
+
+ memset(y[i].scales, 0, 12);
+
+ float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = Ls[j];
+ if (j < 8) {
+ y[i].scales[j] = l & 0xF;
+ } else {
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
+ }
+ l >>= 4;
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+ }
+ y[i].d = GGML_FP32_TO_FP16(d_block);
+
+ int8_t sc;
+ for (int j = 0; j < QK_K/16; ++j) {
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+ float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) {
+ continue;
+ }
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-4, MIN(3, l));
+ L[16*j + ii] = l + 4;
+ }
+ }
+
+ memset(y[i].hmask, 0, QK_K/8);
+ // We put the high bit of the first QK_K/8 quants into bit 0 of hmask, the high bit of the next QK_K/8 into bit 1, etc.
+ int m = 0;
+ uint8_t hm = 1;
+ for (int j = 0; j < QK_K; ++j) {
+ if (L[j] > 3) {
+ y[i].hmask[m] |= hm;
+ L[j] -= 4;
+ }
+ if (++m == QK_K/8) {
+ m = 0; hm <<= 1;
+ }
+ }
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+ }
+ }
+
+ x += QK_K;
+ }
+}
+
+size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
+ if (!quant_weights) {
+ quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
+ }
+ else {
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ }
+ return nrow * row_size;
+}
+
+// ====================== 4-bit (de)-quantization
+
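+// Layout summary for block_q4_K (QK_K = 256), as used below:
+//   - scales[12]: eight 6-bit scales and eight 6-bit mins, one pair per 32-value
+//                 block, packed as decoded by get_scale_min_k4: bytes 0-3 (4-7)
+//                 hold the low 6 bits of scales (mins) 0-3 plus, in bits 6-7, the
+//                 top 2 bits of scales (mins) 4-7; bytes 8-11 hold the low 4 bits
+//                 of scales 4-7 (low nibble) and mins 4-7 (high nibble)
+//   - d, dmin   : fp16 super-block scales for the block scales and mins
+//   - qs[128]   : 4-bit quants, two per byte
+// Dequantization reconstructs x ~= d*sc*q - dmin*m with q in [0, 15].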
+void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[32];
+ float weights[32];
+ float mins[QK_K/32];
+ float scales[QK_K/32];
+
+ for (int i = 0; i < nb; i++) {
+ float max_scale = 0; // since we subtract the min, scales are always positive
+ float max_min = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+ float sum_x2 = 0;
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+ float av_x = sqrtf(sum_x2/32);
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+ scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+ float scale = scales[j];
+ if (scale > max_scale) {
+ max_scale = scale;
+ }
+ float min = mins[j];
+ if (min > max_min) {
+ max_min = min;
+ }
+ }
+
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
+ for (int j = 0; j < QK_K/32; ++j) {
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
+ uint8_t lm = nearest_int(inv_min*mins[j]);
+ ls = MIN(63, ls);
+ lm = MIN(63, lm);
+ if (j < 4) {
+ y[i].scales[j] = ls;
+ y[i].scales[j+4] = lm;
+ } else {
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
+ }
+ }
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+ uint8_t sc, m;
+ for (int j = 0; j < QK_K/32; ++j) {
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+ for (int ii = 0; ii < 32; ++ii) {
+ int l = nearest_int((x[32*j + ii] + dm)/d);
+ l = MAX(0, MIN(15, l));
+ L[32*j + ii] = l;
+ }
+ }
+
+ uint8_t * q = y[i].qs;
+ for (int j = 0; j < QK_K; j += 64) {
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+ q += 32;
+ }
+
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+ const uint8_t * q = x[i].qs;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+ int is = 0;
+ uint8_t sc, m;
+ for (int j = 0; j < QK_K; j += 64) {
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+ const float d1 = d * sc; const float m1 = min * m;
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+ const float d2 = d * sc; const float m2 = min * m;
+ for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+ for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
+ q += 32; is += 2;
+ }
+ }
+}
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_q4_K * restrict y = vy;
+ quantize_row_q4_K_ref(x, y, k);
+}
+
+static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+ assert(n_per_row % QK_K == 0);
+ const int64_t nb = n_per_row / QK_K;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[32];
+ uint8_t Ls[QK_K/32];
+ uint8_t Lm[QK_K/32];
+ float weights[32];
+ float sw[QK_K/32];
+ float mins[QK_K/32];
+ float scales[QK_K/32];
+
+ for (int i = 0; i < nb; i++) {
+
+ float sum_x2 = 0;
+ for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+ float sigma2 = 2*sum_x2/QK_K;
+ float av_x = sqrtf(sigma2);
+
+ for (int j = 0; j < QK_K/32; ++j) {
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*i + 32*j;
+ for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+ } else {
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+ }
+ float sumw = 0;
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
+ sw[j] = sumw;
+ scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+ }
+
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
+ for (int j = 0; j < QK_K/32; ++j) {
+ uint8_t ls = Ls[j];
+ uint8_t lm = Lm[j];
+ if (j < 4) {
+ y[i].scales[j] = ls;
+ y[i].scales[j+4] = lm;
+ } else {
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
+ }
+ }
+ y[i].d = GGML_FP32_TO_FP16(d_block);
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
+
+ uint8_t sc, m;
+ for (int j = 0; j < QK_K/32; ++j) {
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+ for (int ii = 0; ii < 32; ++ii) {
+ int l = nearest_int((x[32*j + ii] + dm)/d);
+ l = MAX(0, MIN(15, l));
+ L[32*j + ii] = l;
+ }
+ }
+ uint8_t * q = y[i].qs;
+ for (int j = 0; j < QK_K; j += 64) {
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+ q += 32;
+ }
+
+ x += QK_K;
+
+ }
+}
+
+size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
+ if (!quant_weights) {
+ quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
+ }
+ else {
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ }
+ return nrow * row_size;
+}
+
+// ====================== 5-bit (de)-quantization
+
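+// Layout summary for block_q5_K (QK_K = 256): the same 12-byte scale/min packing
+// as Q4_K (see get_scale_min_k4), but the quants are 5-bit: the low 4 bits go
+// into the qs nibbles and the 5th bit into the qh bit-planes (two bits per
+// 64-value chunk, selected by the m1/m2 masks below).
+// Dequantization reconstructs x ~= d*sc*q - dmin*m with q in [0, 31].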
+void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ uint8_t L[QK_K];
+ float mins[QK_K/32];
+ float scales[QK_K/32];
+ float weights[32];
+ uint8_t Laux[32];
+
+ for (int i = 0; i < nb; i++) {
+ float max_scale = 0; // since we subtract the min, scales are always positive
+ float max_min = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+ float sum_x2 = 0;
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+ float av_x = sqrtf(sum_x2/32);
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
+ float scale = scales[j];
+ if (scale > max_scale) {
+ max_scale = scale;
+ }
+ float min = mins[j];
+ if (min > max_min) {
+ max_min = min;
+ }
+ }
+
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
+ for (int j = 0; j < QK_K/32; ++j) {
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
+ uint8_t lm = nearest_int(inv_min*mins[j]);
+ ls = MIN(63, ls);
+ lm = MIN(63, lm);
+ if (j < 4) {
+ y[i].scales[j] = ls;
+ y[i].scales[j+4] = lm;
+ } else {
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
+ }
+ }
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+ uint8_t sc, m;
+ for (int j = 0; j < QK_K/32; ++j) {
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+ for (int ii = 0; ii < 32; ++ii) {
+ int l = nearest_int((x[32*j + ii] + dm)/d);
+ l = MAX(0, MIN(31, l));
+ L[32*j + ii] = l;
+ }
+ }
+
+ uint8_t * restrict qh = y[i].qh;
+ uint8_t * restrict ql = y[i].qs;
+ memset(qh, 0, QK_K/8);
+
+ uint8_t m1 = 1, m2 = 2;
+ for (int n = 0; n < QK_K; n += 64) {
+ for (int j = 0; j < 32; ++j) {
+ int l1 = L[n + j];
+ if (l1 > 15) {
+ l1 -= 16; qh[j] |= m1;
+ }
+ int l2 = L[n + j + 32];
+ if (l2 > 15) {
+ l2 -= 16; qh[j] |= m2;
+ }
+ ql[j] = l1 | (l2 << 4);
+ }
+ m1 <<= 2; m2 <<= 2;
+ ql += 32;
+ }
+
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+ const uint8_t * ql = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+ int is = 0;
+ uint8_t sc, m;
+ uint8_t u1 = 1, u2 = 2;
+ for (int j = 0; j < QK_K; j += 64) {
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+ const float d1 = d * sc; const float m1 = min * m;
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+ const float d2 = d * sc; const float m2 = min * m;
+ for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+ for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+ ql += 32; is += 2;
+ u1 <<= 2; u2 <<= 2;
+ }
+ }
+}
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_q5_K * restrict y = vy;
+ quantize_row_q5_K_ref(x, y, k);
+}
+
+static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+ assert(n_per_row % QK_K == 0);
+ const int64_t nb = n_per_row / QK_K;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[32];
+ uint8_t Ls[QK_K/32];
+ uint8_t Lm[QK_K/32];
+ float mins[QK_K/32];
+ float scales[QK_K/32];
+ float sw[QK_K/32];
+ float weights[32];
+
+ for (int i = 0; i < nb; i++) {
+
+ float sum_x2 = 0;
+ for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+ float sigma2 = 2*sum_x2/QK_K;
+ float av_x = sqrtf(sigma2);
+
+ for (int j = 0; j < QK_K/32; ++j) {
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*i + 32*j;
+ for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+ } else {
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+ }
+ float sumw = 0;
+ for (int l = 0; l < 32; ++l) sumw += weights[l];
+ sw[j] = sumw;
+
+ scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+ }
+
+ float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
+ float m_block = make_qp_quants(QK_K/32, 63, mins, Lm, sw);
+
+ for (int j = 0; j < QK_K/32; ++j) {
+ uint8_t ls = Ls[j];
+ uint8_t lm = Lm[j];
+ ls = MIN(63, ls);
+ lm = MIN(63, lm);
+ if (j < 4) {
+ y[i].scales[j] = ls;
+ y[i].scales[j+4] = lm;
+ } else {
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
+ }
+ }
+ y[i].d = GGML_FP32_TO_FP16(d_block);
+ y[i].dmin = GGML_FP32_TO_FP16(m_block);
+
+ uint8_t sc, m;
+ for (int j = 0; j < QK_K/32; ++j) {
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+ for (int ii = 0; ii < 32; ++ii) {
+ int l = nearest_int((x[32*j + ii] + dm)/d);
+ l = MAX(0, MIN(31, l));
+ L[32*j + ii] = l;
+ }
+ }
+
+ uint8_t * restrict qh = y[i].qh;
+ uint8_t * restrict ql = y[i].qs;
+ memset(qh, 0, QK_K/8);
+
+ uint8_t m1 = 1, m2 = 2;
+ for (int n = 0; n < QK_K; n += 64) {
+ for (int j = 0; j < 32; ++j) {
+ int l1 = L[n + j];
+ if (l1 > 15) {
+ l1 -= 16; qh[j] |= m1;
+ }
+ int l2 = L[n + j + 32];
+ if (l2 > 15) {
+ l2 -= 16; qh[j] |= m2;
+ }
+ ql[j] = l1 | (l2 << 4);
+ }
+ m1 <<= 2; m2 <<= 2;
+ ql += 32;
+ }
+
+ x += QK_K;
+
+ }
+}
+
+size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
+ if (!quant_weights) {
+ quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
+ }
+ else {
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ }
+ return nrow * row_size;
+}
+
+// ====================== 6-bit (de)-quantization
+
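+// Layout summary for block_q6_K (QK_K = 256), as used below:
+//   - scales[16]: signed 8-bit scale per 16-value sub-block
+//   - d         : fp16 super-block scale
+//   - ql[128]   : low 4 bits of the 6-bit quants, two per byte
+//   - qh[64]    : top 2 bits of the quants, four per byte
+// Dequantization reconstructs x ~= d * scales[j] * (q - 32) with q in [0, 63].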
+void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ int8_t L[QK_K];
+ float scales[QK_K/16];
+
+ for (int i = 0; i < nb; i++) {
+
+ float max_scale = 0;
+ float max_abs_scale = 0;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+
+ const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
+ scales[ib] = scale;
+
+ const float abs_scale = fabsf(scale);
+ if (abs_scale > max_abs_scale) {
+ max_abs_scale = abs_scale;
+ max_scale = scale;
+ }
+
+ }
+
+ if (max_abs_scale < GROUP_MAX_EPS) {
+ memset(&y[i], 0, sizeof(block_q6_K));
+ y[i].d = GGML_FP32_TO_FP16(0.f);
+ x += QK_K;
+ continue;
+ }
+
+ float iscale = -128.f/max_scale;
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+ }
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+ if (!d) {
+ continue;
+ }
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-32, MIN(31, l));
+ L[16*j + ii] = l + 32;
+ }
+ }
+
+ uint8_t * restrict ql = y[i].ql;
+ uint8_t * restrict qh = y[i].qh;
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ const uint8_t q1 = L[j + l + 0] & 0xF;
+ const uint8_t q2 = L[j + l + 32] & 0xF;
+ const uint8_t q3 = L[j + l + 64] & 0xF;
+ const uint8_t q4 = L[j + l + 96] & 0xF;
+ ql[l+ 0] = q1 | (q3 << 4);
+ ql[l+32] = q2 | (q4 << 4);
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+ }
+ ql += 64;
+ qh += 32;
+ }
+
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict ql = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict sc = x[i].scales;
+
+ for (int n = 0; n < QK_K; n += 128) {
+ for (int l = 0; l < 32; ++l) {
+ int is = l/16;
+ const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ y[l + 0] = d * sc[is + 0] * q1;
+ y[l + 32] = d * sc[is + 2] * q2;
+ y[l + 64] = d * sc[is + 4] * q3;
+ y[l + 96] = d * sc[is + 6] * q4;
+ }
+ y += 128;
+ ql += 64;
+ qh += 32;
+ sc += 8;
+ }
+ }
+}
+
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_q6_K * restrict y = vy;
+ quantize_row_q6_K_ref(x, y, k);
+}
+
+static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+ assert(n_per_row % QK_K == 0);
+ const int64_t nb = n_per_row / QK_K;
+
+ int8_t L[QK_K];
+ float scales[QK_K/16];
+ //float weights[16];
+
+ for (int i = 0; i < nb; i++) {
+
+ //float sum_x2 = 0;
+ //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
+ //float sigma2 = sum_x2/QK_K;
+
+ float max_scale = 0;
+ float max_abs_scale = 0;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+
+ float scale;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*i + 16*ib;
+ //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
+ //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
+ } else {
+ scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
+ }
+ scales[ib] = scale;
+
+ const float abs_scale = fabsf(scale);
+ if (abs_scale > max_abs_scale) {
+ max_abs_scale = abs_scale;
+ max_scale = scale;
+ }
+
+ }
+
+ if (max_abs_scale < GROUP_MAX_EPS) {
+ memset(&y[i], 0, sizeof(block_q6_K));
+ y[i].d = GGML_FP32_TO_FP16(0.f);
+ x += QK_K;
+ continue;
+ }
+
+ float iscale = -128.f/max_scale;
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+ }
+
+ for (int j = 0; j < QK_K/16; ++j) {
+ float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+ if (!d) {
+ continue;
+ }
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int(x[16*j + ii]/d);
+ l = MAX(-32, MIN(31, l));
+ L[16*j + ii] = l + 32;
+ }
+ }
+
+ uint8_t * restrict ql = y[i].ql;
+ uint8_t * restrict qh = y[i].qh;
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ const uint8_t q1 = L[j + l + 0] & 0xF;
+ const uint8_t q2 = L[j + l + 32] & 0xF;
+ const uint8_t q3 = L[j + l + 64] & 0xF;
+ const uint8_t q4 = L[j + l + 96] & 0xF;
+ ql[l+ 0] = q1 | (q3 << 4);
+ ql[l+32] = q2 | (q4 << 4);
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+ }
+ ql += 64;
+ qh += 32;
+ }
+
+ x += QK_K;
+
+ }
+}
+
+size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
+ if (!quant_weights) {
+ quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
+ }
+ else {
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ }
+ return nrow * row_size;
+}
+
+static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK4_0 == 32, "QK4_0 must be 32");
+
+ if (!quant_weights) {
+ quantize_row_q4_0_ref(x, y, n_per_row);
+ return;
+ }
+
+ float weight[QK4_0];
+ int8_t L[QK4_0];
+
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ float sigma2 = sum_x2/n_per_row;
+
+ const int64_t nb = n_per_row/QK4_0;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK4_0 * ib;
+ const float * qw = quant_weights + QK4_0 * ib;
+ for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+ for (int j = 0; j < 16; ++j) {
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
+ }
+ }
+}
+
+size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+ }
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
+static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK4_1 == 32, "QK4_1 must be 32");
+
+ if (!quant_weights) {
+ quantize_row_q4_1_ref(x, y, n_per_row);
+ return;
+ }
+
+ float weight[QK4_1];
+ uint8_t L[QK4_1], Laux[QK4_1];
+
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ float sigma2 = sum_x2/n_per_row;
+
+ const int64_t nb = n_per_row/QK4_1;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK4_1 * ib;
+ const float * qw = quant_weights + QK4_1 * ib;
+ for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ float min;
+ float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+ y[ib].m = GGML_FP32_TO_FP16(-min);
+ for (int j = 0; j < 16; ++j) {
+ y[ib].qs[j] = L[j] | (L[j+16] << 4);
+ }
+ }
+}
+
+size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+ }
+ size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
+static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK5_0 == 32, "QK5_0 must be 32");
+
+ if (!quant_weights) {
+ quantize_row_q5_0_ref(x, y, n_per_row);
+ return;
+ }
+
+ float weight[QK5_0];
+ int8_t L[QK5_0];
+
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ float sigma2 = sum_x2/n_per_row;
+
+ const int64_t nb = n_per_row/QK5_0;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK5_0 * ib;
+ const float * qw = quant_weights + QK5_0 * ib;
+ for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < 16; ++j) {
+ const uint8_t xi0 = L[j];
+ const uint8_t xi1 = L[j+16];
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
+ }
+}
+
+size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+ }
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
+static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+ static_assert(QK5_1 == 32, "QK5_1 must be 32");
+
+ if (!quant_weights) {
+ quantize_row_q5_1_ref(x, y, n_per_row);
+ return;
+ }
+
+ float weight[QK5_1];
+ uint8_t L[QK5_1], Laux[QK5_1];
+
+ float sum_x2 = 0;
+ for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+ float sigma2 = sum_x2/n_per_row;
+
+ const int64_t nb = n_per_row/QK5_1;
+ for (int ib = 0; ib < nb; ++ib) {
+ const float * xb = x + QK5_1 * ib;
+ const float * qw = quant_weights + QK5_1 * ib;
+ for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ float min;
+ float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+ y[ib].d = GGML_FP32_TO_FP16(d);
+ y[ib].m = GGML_FP32_TO_FP16(-min);
+
+ uint32_t qh = 0;
+ for (int j = 0; j < 16; ++j) {
+ const uint8_t xi0 = L[j];
+ const uint8_t xi1 = L[j+16];
+ y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ memcpy(&y[ib].qh, &qh, sizeof(qh));
+ }
+}
+
+size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ if (!quant_weights) {
+ quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+ }
+ size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += row_size;
+ }
+ return nrow * row_size;
+}
+
+size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ (void)quant_weights; // not used
+ const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
+ quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * row_size;
+}
+
+// ====================== "True" 2-bit (de)-quantization
+
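+// IQ2_XXS is codebook based: each group of 8 weights is an entry of iq2xxs_grid
+// with per-weight signs taken from ksigns_iq2xs. Per 32 values the block stores
+// eight bytes, read below as two 32-bit words: the first holds four 8-bit grid
+// indices, the second holds four 7-bit sign selectors in its low 28 bits and a
+// 4-bit scale in its top 4 bits, giving db = d * (0.5 + scale) / 4.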
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+ const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ }
+ }
+}
+
+// ====================== 2.3125 bpw (de)-quantization
+
+void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ float db[2];
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
+ db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
+ const uint8_t signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
+ for (int j = 0; j < 8; ++j) {
+ y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ }
+ }
+}
+
+// ====================== 2.5625 bpw (de)-quantization
+
+void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ float db[2];
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint8_t * signs = qs + QK_K/8;
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
+ db[1] = d * (0.5f + (x[i].scales[ib32] >> 4)) * 0.25f;
+ for (int l = 0; l < 4; ++l) {
+ const float dl = db[l/2];
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ qs += 4;
+ signs += 4;
+ }
+ }
+}
+
+// ====================== 3.0625 bpw (de)-quantization
+
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ uint32_t aux32;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * scales_and_signs = qs + QK_K/4;
+
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
+ const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ qs += 8;
+ }
+ }
+}
+
+// ====================== 3.3125 bpw (de)-quantization
+
+void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint8_t * signs = x[i].signs;
+
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
+ const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >> 4));
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ qs += 8;
+ signs += 4;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ y += 8;
+ }
+ qh += 2;
+ qs += 8;
+ signs += 4;
+ }
+ }
+}
+
+// ====================== 1.5625 bpw (de)-quantization
+
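+// IQ1_S is also codebook based: each group of 8 weights is an entry of iq1s_grid
+// addressed by an 8-bit index from qs plus 3 high bits from qh. Per 32 values,
+// qh additionally carries a 3-bit scale (bits 12-14) and a sign bit for the
+// shared delta (bit 15), so x ~= d * (2*s + 1) * (grid[j] +/- IQ1S_DELTA).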
+void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const float dl = d * (2*((qh[ib] >> 12) & 7) + 1);
+ const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA;
+ for (int l = 0; l < 4; ++l) {
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl * (grid[j] + delta);
+ }
+ y += 8;
+ }
+ qs += 4;
+ }
+ }
+}
+
+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ float delta[4];
+ uint16_t idx[4];
+
+ iq1m_scale_t scale;
+
+ for (int i = 0; i < nb; i++) {
+
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = GGML_FP16_TO_FP32(scale.f16);
+
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
+ const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+
+ idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
+ idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
+ idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
+ idx[3] = qs[3] | ((qh[1] << 4) & 0x700);
+ delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+ delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+ delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+ delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+ for (int l = 0; l < 2; ++l) {
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl1 * (grid[j] + delta[l]);
+ }
+ y += 8;
+ }
+ for (int l = 2; l < 4; ++l) {
+ const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+ for (int j = 0; j < 8; ++j) {
+ y[j] = dl2 * (grid[j] + delta[l]);
+ }
+ y += 8;
+ }
+ qs += 4;
+ qh += 2;
+ }
+ }
+}
+
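+// IQ4_NL and IQ4_XS are non-linear 4-bit formats: each nibble is an index into
+// the fixed 16-entry table below (kvalues_iq4nl) rather than the value itself,
+// scaled by the per-block fp16 d (and, for IQ4_XS, by a 6-bit per-32 sub-block
+// scale decoded as ls - 32).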
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK4_NL == 0);
+ const int64_t nb = k / QK4_NL;
+
+ for (int i = 0; i < nb; i++) {
+
+ const uint8_t * qs = x[i].qs;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf];
+ y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4];
+ }
+ y += QK4_NL;
+ qs += QK4_NL/2;
+ }
+}
+
+void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const uint8_t * qs = x[i].qs;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
+ const float dl = d * (ls - 32);
+ for (int j = 0; j < 16; ++j) {
+ y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+ y[j+16] = dl * kvalues_iq4nl[qs[j] >> 4];
+ }
+ y += 32;
+ qs += 16;
+ }
+ }
+}
+
+//===================================== Q8_K ==============================================
+
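+// Q8_K is the 8-bit format used for the activation side of the K-quant dot
+// products: a float scale d, QK_K int8 quants, and per-16 block sums (bsums)
+// that several of the vec_dot kernels use to apply the block mins.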
+void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ float max = 0;
+ float amax = 0;
+ for (int j = 0; j < QK_K; ++j) {
+ float ax = fabsf(x[j]);
+ if (ax > amax) {
+ amax = ax; max = x[j];
+ }
+ }
+ if (!amax) {
+ y[i].d = 0;
+ memset(y[i].qs, 0, QK_K);
+ x += QK_K;
+ continue;
+ }
+ //const float iscale = -128.f/max;
+ // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+ const float iscale = -127.f/max;
+ for (int j = 0; j < QK_K; ++j) {
+ int v = nearest_int(iscale*x[j]);
+ y[i].qs[j] = MIN(127, v);
+ }
+ for (int j = 0; j < QK_K/16; ++j) {
+ int sum = 0;
+ for (int ii = 0; ii < 16; ++ii) {
+ sum += y[i].qs[j*16 + ii];
+ }
+ y[i].bsums[j] = sum;
+ }
+ y[i].d = 1/iscale;
+ x += QK_K;
+ }
+}
+
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ const int64_t nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+ for (int j = 0; j < QK_K; ++j) {
+ *y++ = x[i].d * x[i].qs[j];
+ }
+ }
+}
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+ quantize_row_q8_K_ref(x, y, k);
+}
+
+//===================================== Dot products =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+ static const uint8_t k_shuffle[256] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+ };
+ return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
+#elif defined(__loongarch_asx)
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+ static const uint8_t k_shuffle[256] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+ };
+ return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m128i get_scale_shuffle(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+ };
+ return __lsx_vld((const __m128i*)k_shuffle + i, 0);
+}
+#endif
+
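+// All vec_dot kernels below follow the same pattern: an optional iqk_mul_mat
+// fast path (when GGML_USE_IQK_MULMAT is defined, it performs the whole
+// multiplication and the function returns early), followed by architecture
+// specific implementations (SVE, NEON, AVX2, AVX, SSSE3, ...) that accumulate,
+// per block, d_x * d_y * dot(q_x, q_y) into sumf.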
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q4_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q4_0 * restrict vx0 = vx;
+ const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
+ const block_q8_0 * restrict vy0 = vy;
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q4_0 * restrict b_x0 = &vx0[i];
+ const block_q4_0 * restrict b_x1 = &vx1[i];
+ const block_q8_0 * restrict b_y0 = &vy0[i];
+ const block_q8_0 * restrict b_y1 = &vy1[i];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+ const int8x16_t s8b = vdupq_n_s8(0x8);
+
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // sub 8
+ const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+ const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+ const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+ const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+ float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+
+ float32x4_t scale = vld1q_f32(_scale);
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+ l1, r1)), l2, r2)), l3, r3))), scale);
+ }
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+ vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+ if (svcntb() == QK8_0) {
+ const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+ const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q4_0 * restrict x0 = &x[ib + 0];
+ const block_q4_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
+
+ // load x
+ const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+ const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+ // 4-bit -> 8-bit
+ const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+ const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+ // sub 8
+ const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+ const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+ // load y
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+ // dot product
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+ }
+#elif defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q4_0 * restrict x0 = &x[ib + 0];
+ const block_q4_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+ const int8x16_t s8b = vdupq_n_s8(0x8);
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // sub 8
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+ // dot product into int32x4_t
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = _mm256_set1_epi8( 8 );
+ qx = _mm256_sub_epi8( qx, off );
+
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps( d, q, acc );
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ // Compute combined scale for the block
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ const __m128i off = _mm_set1_epi8(8);
+
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+ by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
+
+ // Convert int32_t to float
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
+
+ // Apply the scale, and accumulate
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+ // set constants
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ const __m128i off = _mm_set1_epi8(8);
+
+ // Initialize accumulator with zeros
+ __m128 acc_0 = _mm_setzero_ps();
+ __m128 acc_1 = _mm_setzero_ps();
+ __m128 acc_2 = _mm_setzero_ps();
+ __m128 acc_3 = _mm_setzero_ps();
+
+ for (; ib + 1 < nb; ib += 2) {
+ _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 0 and 1
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
+ bx_1 = _mm_sub_epi8(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 2 and 3
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+ bx_2 = _mm_sub_epi8(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
+ bx_3 = _mm_sub_epi8(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+ // Apply the scale
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+ // Accumulate
+ acc_0 = _mm_add_ps(p0_d, acc_0);
+ acc_1 = _mm_add_ps(p1_d, acc_1);
+ acc_2 = _mm_add_ps(p2_d, acc_2);
+ acc_3 = _mm_add_ps(p3_d, acc_3);
+ }
+
+ sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+ // mask off the low nibbles of x, then shift down the high nibbles
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ // subtract offset
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector signed char v8 = vec_splats((signed char)0x8);
+
+ vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+ vector signed char q4x0 = vec_and(qxs, lowMask);
+ vector signed char q4x1 = vec_sr(qxs, v4);
+
+ q4x0 = vec_sub(q4x0, v8);
+ q4x1 = vec_sub(q4x1, v8);
+
+ vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+ vector signed int vsumi0 = v0;
+
+ vsumi0 = vec_sum4s(qv0, vsumi0);
+ vsumi0 = vec_sum4s(qv1, vsumi0);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+ // Initialize accumulator with zeros
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ /* Compute combined scale for the block */
+ const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = __lasx_xvreplgr2vr_b( 8 );
+ qx = __lasx_xvsub_b( qx, off );
+
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ /* Multiply q with scale and accumulate */
+ acc = __lasx_xvfmadd_s( d, q, acc );
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__loongarch_sx)
+ // set constants
+ const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
+ const __m128i off = __lsx_vreplgr2vr_b(8);
+
+ // Initialize accumulator with zeros
+ __m128 acc_0 = __lsx_vldi(0);
+ __m128 acc_1 = __lsx_vldi(0);
+ __m128 acc_2 = __lsx_vldi(0);
+ __m128 acc_3 = __lsx_vldi(0);
+
+ for (; ib + 1 < nb; ib += 2) {
+
+ // Compute combined scale for the block 0 and 1
+ const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+ const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
+
+ __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
+ __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
+ bx_0 = __lsx_vsub_b(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
+ __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
+ bx_1 = __lsx_vsub_b(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+ //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 2 and 3
+ const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+ const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
+
+ __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
+ __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
+ bx_2 = __lsx_vsub_b(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
+ __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
+ bx_3 = __lsx_vsub_b(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = __lsx_vffint_s_w(i32_0);
+ __m128 p1 = __lsx_vffint_s_w(i32_1);
+ __m128 p2 = __lsx_vffint_s_w(i32_2);
+ __m128 p3 = __lsx_vffint_s_w(i32_3);
+
+ // Apply the scale
+ __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
+ __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
+ __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
+ __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );
+
+ // Accumulate
+ acc_0 = __lsx_vfadd_s(p0_d, acc_0);
+ acc_1 = __lsx_vfadd_s(p1_d, acc_1);
+ acc_2 = __lsx_vfadd_s(p2_d, acc_2);
+ acc_3 = __lsx_vfadd_s(p3_d, acc_3);
+ }
+
+ sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#endif
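+ // scalar tail for any remaining blocks: each qs byte packs two 4-bit quants
+ // (low nibble first), re-centered by -8 and weighted by the product of the block scales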
+ for (; ib < nb; ++ib) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
+
+ sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
+ }
+
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+ }
+
+ *s = sumf;
+}
+
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q4_1 * restrict x = vx;
+ const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
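+ // nrc == 2: int8 matrix-multiply (SMMLA) path computing a 2x2 tile of dot products,
+ // i.e. two rows of x against two rows of y; results are written to s and s + bs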
+ const block_q4_1 * restrict vx0 = vx;
+ const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
+ const block_q8_1 * restrict vy0 = vy;
+ const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q4_1 * restrict b_x0 = &vx0[i];
+ const block_q4_1 * restrict b_x1 = &vx1[i];
+ const block_q8_1 * restrict b_y0 = &vy0[i];
+ const block_q8_1 * restrict b_y1 = &vy1[i];
+
+ float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+ GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+ GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+ summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+ // mmla into int32x4_t
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
+ GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
+ GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
+ GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+ float32x4_t scale = vld1q_f32(_scale);
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+ l1, r1)), l2, r2)), l3, r3))), scale);
+ }
+
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+ sumv2 = vaddq_f32(sumv2, summs0);
+
+ vst1_f32(s, vget_low_f32 (sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
+
+ int ib = 0;
+ float sumf = 0;
+
+ // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
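+ // q4_1 keeps unsigned nibbles plus a per-block minimum; the min contribution m*s
+ // is accumulated separately in summs and added to the dot product at the end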
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ float summs = 0;
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q4_1 * restrict x0 = &x[ib + 0];
+ const block_q4_1 * restrict x1 = &x[ib + 1];
+ const block_q8_1 * restrict y0 = &y[ib + 0];
+ const block_q8_1 * restrict y1 = &y[ib + 1];
+
+ summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+ // dot product into int32x4_t
+ const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+ const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0;
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+ const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+ const __m256 d0v = _mm256_set1_ps( d0 );
+ const __m256 d1v = _mm256_set1_ps( d1 );
+
+ // Compute combined scales
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
+
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+ // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+ }
+
+ sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+ // mask off the low nibbles of x, then shift down the high nibbles
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+ vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
+ vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+ vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
+ vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
+
+ vector signed int vsumi0 = v0;
+
+ vsumi0 = vec_msum(q8y0, q4x0, vsumi0);
+ vsumi0 = vec_msum(q8y1, q4x1, vsumi0);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+ // Initialize accumulator with zeros
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ float summs = 0;
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+ const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+ const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
+ const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
+
+ // Compute combined scales
+ const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
+
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
+
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+ // Accumulate d0*d1*x*y
+ acc = __lasx_xvfmadd_s( d0d1, xy, acc );
+ }
+
+ sumf = hsum_float_8(acc) + summs;
+#endif
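+ // scalar tail: q4_1 nibbles stay unsigned; the per-block minimum enters via the m*s term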
+ for (; ib < nb; ++ib) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[ib].qs[j] & 0x0F);
+ const int v1 = (x[ib].qs[j] >> 4);
+
+ sumi += (v0 * y[ib].qs[j]) + (v1 * y[ib].qs[j + qk/2]);
+ }
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+ }
+
+ *s = sumf;
+}
+
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ int ib = 0;
+ float sumf = 0;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ uint32_t qh0;
+ uint32_t qh1;
+
+ uint64_t tmp0[4];
+ uint64_t tmp1[4];
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q5_0 * restrict x0 = &x[ib];
+ const block_q5_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ // extract the 5th bit via lookup table ((!b) << 4)
+ memcpy(&qh0, x0->qh, sizeof(qh0));
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
+
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // apply the 5th bit and subtract 16: x - ((!b) << 4) == (x | (b << 4)) - 16
+ const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+ const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+ const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+ const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+ v128_t sumv = wasm_f32x4_splat(0.0f);
+
+ uint32_t qh;
+ uint64_t tmp[4];
+
+ // TODO: check if unrolling this is better
+ for (; ib < nb; ++ib) {
+ const block_q5_0 * restrict x0 = &x[ib];
+ const block_q8_0 * restrict y0 = &y[ib];
+
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+ // extract the 5th bit
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_1[(qh >> 24) ];
+
+ const v128_t qhl = wasm_v128_load(tmp + 0);
+ const v128_t qhh = wasm_v128_load(tmp + 2);
+
+ const v128_t v0 = wasm_v128_load(x0->qs);
+
+ // 4-bit -> 8-bit
+ const v128_t v0l = wasm_v128_and (v0, m4b);
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+ // apply the 5th bit and subtract 16: x - ((!b) << 4) == (x | (b << 4)) - 16
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+ // load y
+ const v128_t v1l = wasm_v128_load(y0->qs);
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+ // int8x16 -> int16x8
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+ // dot product
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+ wasm_i32x4_add(
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+ }
+
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+ qx = _mm256_or_si256(qx, bxhi);
+
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps(d, q, acc);
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8((char)0xF0);
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_andnot_si128(bxhil, mask);
+ bxhih = _mm_andnot_si128(bxhih, mask);
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx_0 = MM256_SET_M128I(bxh, bxl);
+
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+ uint32_t qh;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ // These temporary registers are for masking and shift operations
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+ for (; ib < nb; ++ib) {
+ memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+ // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+ // ((qh & (1u << (j + 16))) >> (j + 12));
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+ // narrowing
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+ // load
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector unsigned char v4 = vec_splats((unsigned char)4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};
+ vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};
+
+ vector signed char qh0 = (vector signed char)aux64x2_0;
+ vector signed char qh1 = (vector signed char)aux64x2_1;
+
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+ vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);
+ vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);
+
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+ vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
+
+ qv0 = vec_add(qv0, qv1);
+
+ vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+ // Initialize accumulator with zeros
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ /* Compute combined scale for the block */
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
+ qx = __lasx_xvor_v(qx, bxhi);
+
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ /* Multiply q with scale and accumulate */
+ acc = __lasx_xvfmadd_s(d, q, acc);
+ }
+
+ sumf = hsum_float_8(acc);
+#endif
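+ // scalar tail: the 5th bit of each quant is taken from qh, so each value is
+ // ((low nibble | high bit << 4) - 16) before being weighted by the block scales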
+ for (; ib < nb; ++ib) {
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+ const int32_t x0 = ((x[ib].qs[j] & 0x0F) | xh_0) - 16;
+ const int32_t x1 = ((x[ib].qs[j] >> 4) | xh_1) - 16;
+
+ sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]);
+ }
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+ }
+
+ *s = sumf;
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ int ib = 0;
+ float sumf = 0;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_1);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_1 * restrict x = vx;
+ const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ float summs0 = 0.0f;
+ float summs1 = 0.0f;
+
+ uint32_t qh0;
+ uint32_t qh1;
+
+ uint64_t tmp0[4];
+ uint64_t tmp1[4];
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q5_1 * restrict x0 = &x[ib];
+ const block_q5_1 * restrict x1 = &x[ib + 1];
+ const block_q8_1 * restrict y0 = &y[ib];
+ const block_q8_1 * restrict y1 = &y[ib + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+ summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+ // extract the 5th bit via lookup table ((b) << 4)
+ memcpy(&qh0, x0->qh, sizeof(qh0));
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
+
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // add high bit
+ const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+ const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+ const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+ const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+ v128_t sumv = wasm_f32x4_splat(0.0f);
+
+ float summs = 0.0f;
+
+ uint32_t qh;
+ uint64_t tmp[4];
+
+ // TODO: check if unrolling this is better
+ for (; ib < nb; ++ib) {
+ const block_q5_1 * restrict x0 = &x[ib];
+ const block_q8_1 * restrict y0 = &y[ib];
+
+ summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+ // extract the 5th bit
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_0[(qh >> 24) ];
+
+ const v128_t qhl = wasm_v128_load(tmp + 0);
+ const v128_t qhh = wasm_v128_load(tmp + 2);
+
+ const v128_t v0 = wasm_v128_load(x0->qs);
+
+ // 4-bit -> 8-bit
+ const v128_t v0l = wasm_v128_and (v0, m4b);
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+ // add high bit
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+ // load y
+ const v128_t v1l = wasm_v128_load(y0->qs);
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+ // int8x16 -> int16x8
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+ // dot product
+ sumv = wasm_f32x4_add(sumv,
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+ }
+
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+ qx = _mm256_or_si256(qx, bxhi);
+
+ const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+ }
+
+ sumf = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8(0x10);
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_and_si128(bxhil, mask);
+ bxhih = _mm_and_si128(bxhih, mask);
+ __m128i bxl = _mm256_castsi256_si128(bx_0);
+ __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx_0 = MM256_SET_M128I(bxh, bxl);
+
+ const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
+
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+ }
+
+ sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+ uint32_t qh;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ // temporary registers for shift operations
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+ for (; ib < nb; ++ib) {
+ memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+ // load qh
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+ // ((qh >> (j + 0)) << 4) & 0x10;
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+ // ((qh >> (j + 12)) ) & 0x10;
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+ // narrowing
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+ // load
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+ }
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+ vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};
+ vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+ vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};
+ vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};
+
+ vector signed char qh0 = (vector signed char)aux64x2_0;
+ vector signed char qh1 = (vector signed char)aux64x2_1;
+
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+ vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);
+ vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);
+
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+ vector signed int vsumi0 = v0;
+
+ vsumi0 = vec_msum(q8y0, q5x0, vsumi0);
+ vsumi0 = vec_msum(q8y1, q5x1, vsumi0);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+ // Initialize accumulator with zeros
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d));
+
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+ bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
+ qx = __lasx_xvor_v(qx, bxhi);
+
+ const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+ acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
+ }
+
+ sumf = hsum_float_8(acc) + summs;
+#endif
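+ // scalar tail: same 5-bit reconstruction as q5_0 but kept unsigned, with the per-block m*s term added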
+ for (; ib < nb; ++ib) {
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
+
+ sumi += (x0 * y[ib].qs[j]) + (x1 * y[ib].qs[j + qk/2]);
+ }
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+ }
+
+ *s = sumf;
+}
+
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ assert((nrc == 2) || (nrc == 1));
+#else
+ assert(nrc == 1);
+#endif
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q8_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
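+ // nrc == 2: SMMLA path, analogous to the q4_1 case above, producing a 2x2 tile of q8_0 dot products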
+ const block_q8_0 * restrict vx0 = vx;
+ const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
+ const block_q8_0 * restrict vy0 = vy;
+ const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; i++) {
+ const block_q8_0 * restrict b_x0 = &vx0[i];
+ const block_q8_0 * restrict b_y0 = &vy0[i];
+
+ const block_q8_0 * restrict b_x1 = &vx1[i];
+ const block_q8_0 * restrict b_y1 = &vy1[i];
+
+ const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+ const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+ const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+ const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+ // load y
+ const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+ const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+ const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+ const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+ float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+ GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+ float32x4_t scale = vld1q_f32(_scale);
+
+ int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+ int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+ int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+ int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+ int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+ int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+ int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+ int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+ sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+ l1, r1)), l2, r2)), l3, r3))), scale);
+ }
+ float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+ float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+ vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s + bs, vget_high_f32(sumv2));
+ return;
+ }
+#endif
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+ if (svcntb() == QK8_0) {
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q8_0 * restrict x0 = &x[ib + 0];
+ const block_q8_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
+
+ // load x
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+ // load y
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+ }
+#elif defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q8_0 * restrict x0 = &x[ib + 0];
+ const block_q8_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
+
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+ // load y
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+ ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
+
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__) || defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ // Compute combined scale for the block
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+ __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ // Multiply q with scale and accumulate
+#if defined(__AVX2__)
+ acc = _mm256_fmadd_ps( d, q, acc );
+#else
+ acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
+ }
+
+ sumf = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+ size_t vl = __riscv_vsetvl_e8m1(qk);
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
+ vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+
+ vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+ }
+#elif defined(__POWER9_VECTOR__)
+ const vector signed int v0 = vec_splats((int32_t)0);
+ vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed char q8x0 = vec_xl( 0, x[ib].qs);
+ vector signed char q8x1 = vec_xl(16, x[ib].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+ vector signed short qv0 = vec_mule(q8x0, q8y0);
+ vector signed short qv1 = vec_mulo(q8x0, q8y0);
+ vector signed short qv2 = vec_mule(q8x1, q8y1);
+ vector signed short qv3 = vec_mulo(q8x1, q8y1);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+
+ vsumi0 = vec_sum4s(qv0, vsumi0);
+ vsumi1 = vec_sum4s(qv1, vsumi1);
+ vsumi0 = vec_sum4s(qv2, vsumi0);
+ vsumi1 = vec_sum4s(qv3, vsumi1);
+
+ vsumi0 = vec_add(vsumi0, vsumi1);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+ // Initialize accumulator with zeros
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ // Main loop
+ for (; ib < nb; ++ib) {
+ // Compute combined scale for the block
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+ __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+ // Multiply q with scale and accumulate
+ acc = __lasx_xvfmadd_s( d, q, acc );
+ }
+
+ sumf = hsum_float_8(acc);
+#endif
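+ // scalar tail: plain int8 dot product of the two blocks, scaled by both block scales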
+ for (; ib < nb; ++ib) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk; j++) {
+ sumi += x[ib].qs[j]*y[ib].qs[j];
+ }
+
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+ }
+
+ *s = sumf;
+}
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q2_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
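+ // each q2_K super-block stores QK_K 2-bit quants in sub-blocks of 16; x[i].scales packs a
+ // 4-bit scale (low nibble) and a 4-bit min (high nibble) per sub-block, applied through the
+ // super-block scales d and dmin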
+#ifdef __ARM_NEON
+ const uint8x16_t m3 = vdupq_n_u8(0x3);
+ const uint8x16_t m4 = vdupq_n_u8(0xF);
+
+ const int32x4_t vzero = vdupq_n_s32(0);
+
+ ggml_int8x16x2_t q2bytes;
+ uint8_t aux[16];
+
+ float sum = 0;
+
+ for (int i = 0; i < nb; ++i) {
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ const uint8_t * restrict sc = x[i].scales;
+
+ const uint8x16_t mins_and_scales = vld1q_u8(sc);
+ const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+ vst1q_u8(aux, scales);
+
+ const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+ const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
+ const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+ vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+ const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+ vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+ sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+ int isum = 0;
+ int is = 0;
+
+// We use this macro instead of a function call because for some reason
+// the code runs 2-3% slower, even if the function is declared inline
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+ MULTIPLY_ACCUM_WITH_SCALE((index));
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
+
+ ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+ q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+ q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
+ MULTIPLY_ACCUM_WITH_SCALE(0);
+
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+ SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+ is += 8;
+ }
+
+ sum += d * isum;
+ }
+
+ *s = sum;
+
+#elif defined __AVX2__
+
+ const __m256i m3 = _mm256_set1_epi8(3);
+ const __m128i m4 = _mm_set1_epi8(0xF);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+ const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
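+        // duplicate the 8 sub-block scales of each 128-quant half into both 128-bit
+        // lanes, so the byte shuffles below can place the right scale next to each
+        // 16-bit partial product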
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+ __m256i sumi = _mm256_setzero_si256();
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+ const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+ const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+ const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+ const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+ __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+ __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+ __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+ __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+ p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+ p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+ p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+ p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+ p0 = _mm256_add_epi32(p0, p1);
+ p2 = _mm256_add_epi32(p2, p3);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+ }
+
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(0x3);
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i m2 = _mm_set1_epi8(0x2);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ // load mins and scales from block_q2_K.scales[QK_K/16]
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
+ const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
+
+ // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
+ const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
+ const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
+
+ // sumf += -dmin * summs in 32bits*8
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
+
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
+ const __m128i scales[2] = { scales_0, scales_1 };
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+ // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
+ __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+ const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+ const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+ const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+ const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+ q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+ const __m128i q2_1 = _mm_and_si128(q2bits, m3);
+ const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+ const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+ const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+ // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
+ __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
+ __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
+ __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
+ __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
+ __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
+ __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
+ __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
+ __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
+
+ // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
+
+ p0 = _mm_add_epi32(p0, p1);
+ p2 = _mm_add_epi32(p2, p3);
+ p4 = _mm_add_epi32(p4, p5);
+ p6 = _mm_add_epi32(p6, p7);
+
+ // isum in 32bits*4*2
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
+ }
+
+ // sumf += dall * isum - dmin * summs in 32bits
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+ float sumf = 0;
+ uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
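+    // gather indices (16 zeros followed by 16 ones): vrgather below uses them to
+    // broadcast a pair of adjacent sub-block scales across the 32 active lanes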
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ size_t vl = 16;
+
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+ vl = 32;
+
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+ uint8_t is=0;
+ int isum=0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load Q2
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+
+ // duplicate scale elements for product
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+ // load Q8
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+ q2+=32; q8+=128; is=8;
+
+ }
+
+ sumf += dall * isum;
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0x3);
+ const vector signed char lowScaleMask = vec_splats((signed char)0xF);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+ const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+ vector float vdmin = vec_mul(vxmin, vyd);
+
+ vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+ vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+ vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales);
+ vector signed char vscales = vec_and(q2xmins, lowScaleMask);
+
+ q2xmins = vec_sr(q2xmins, v4);
+ vector signed short q2xmins0 = vec_unpackh(q2xmins);
+ vector signed short q2xmins1 = vec_unpackl(q2xmins);
+
+ vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);
+ vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);
+ vector signed int prod2 = vec_mule(q2xmins1, q8ysums1);
+ vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1);
+
+ vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+ vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+ vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+ vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+ vector signed int vsumi4 = v0;
+ vector signed int vsumi5 = v0;
+ vector signed int vsumi6 = v0;
+ vector signed int vsumi7 = v0;
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __builtin_prefetch(q2, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q2);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q2);
+ q2 += 32;
+
+ vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+ vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);
+ vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);
+ vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);
+ vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+ vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);
+ vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);
+ vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);
+
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y01 = vec_xl( 32, q8);
+ vector signed char q8y11 = vec_xl( 48, q8);
+ vector signed char q8y02 = vec_xl( 64, q8);
+ vector signed char q8y12 = vec_xl( 80, q8);
+ vector signed char q8y03 = vec_xl( 96, q8);
+ vector signed char q8y13 = vec_xl(112, q8);
+ q8 += 128;
+
+ vector signed int qv0 = vec_msum(q8y00, q2x00, v0);
+ vector signed int qv1 = vec_msum(q8y01, q2x01, v0);
+ vector signed int qv2 = vec_msum(q8y02, q2x02, v0);
+ vector signed int qv3 = vec_msum(q8y03, q2x03, v0);
+ vector signed int qv4 = vec_msum(q8y10, q2x10, v0);
+ vector signed int qv5 = vec_msum(q8y11, q2x11, v0);
+ vector signed int qv6 = vec_msum(q8y12, q2x12, v0);
+ vector signed int qv7 = vec_msum(q8y13, q2x13, v0);
+
+ vector signed short vscales_07 = vec_unpackh(vscales);
+ vector signed int vscales_03 = vec_unpackh(vscales_07);
+ vector signed int vscales_47 = vec_unpackl(vscales_07);
+ vector signed int vs0 = vec_splat(vscales_03, 0);
+ vector signed int vs1 = vec_splat(vscales_03, 1);
+ vector signed int vs2 = vec_splat(vscales_03, 2);
+ vector signed int vs3 = vec_splat(vscales_03, 3);
+ vector signed int vs4 = vec_splat(vscales_47, 0);
+ vector signed int vs5 = vec_splat(vscales_47, 1);
+ vector signed int vs6 = vec_splat(vscales_47, 2);
+ vector signed int vs7 = vec_splat(vscales_47, 3);
+ vscales = vec_sld(vscales, vscales, 8);
+
+ vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);
+ vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);
+ vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);
+ vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);
+ vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);
+ vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);
+ vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);
+ }
+
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+ const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+ const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
+
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+ const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
+ const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
+ const __m256i mins = lasx_ext8_16(mins8);
+ const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
+
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
+
+ const __m256i all_scales = lasx_ext8_16(scales8);
+ const __m128i l_scales = lasx_extracti128(all_scales, 0);
+ const __m128i h_scales = lasx_extracti128(all_scales, 1);
+ const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+ __m256i sumi = __lasx_xvldi(0);
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
+
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+ const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
+ const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
+ const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
+ const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+
+ __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
+ __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
+ __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
+ __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+
+ p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
+ p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
+ p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
+ p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+ p0 = __lasx_xvadd_w(p0, p1);
+ p2 = __lasx_xvadd_w(p2, p3);
+
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
+ }
+
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
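+    // scalar version (reference for the SIMD branches above)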
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+
+ int summs = 0;
+ for (int j = 0; j < 16; ++j) {
+ summs += y[i].bsums[j] * (sc[j] >> 4);
+ }
+
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ int isum = 0;
+ int is = 0;
+ int d;
+ for (int k = 0; k < QK_K/128; ++k) {
+ int shift = 0;
+ for (int j = 0; j < 4; ++j) {
+ d = sc[is++] & 0xF;
+ int isuml = 0;
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+ isum += d * isuml;
+ d = sc[is++] & 0xF;
+ isuml = 0;
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+ isum += d * isuml;
+ shift += 2;
+ q8 += 32;
+ }
+ q2 += 32;
+ }
+ sumf += dall * isum - dmin * summs;
+ }
+ *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const uint32_t kmask1 = 0x03030303;
+ const uint32_t kmask2 = 0x0f0f0f0f;
+
+ const block_q3_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
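+    // Every branch below computes the same super-block dot product as the scalar
+    // fallback at the end of this function:
+    //   sumf += d * sum_j (scale[j] - 32) * sum_l q8[l] * q3[l]
+    // where q3[l] = ((qs[l] >> shift) & 3) - (hmask bit set ? 0 : 4) and the 16
+    // 6-bit scales are unpacked from the 12 bytes of x[i].scales via kmask1/kmask2.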
+
+#ifdef __ARM_NEON
+
+ uint32_t aux[3];
+ uint32_t utmp[4];
+
+ const uint8x16_t m3b = vdupq_n_u8(0x3);
+ const int32x4_t vzero = vdupq_n_s32(0);
+
+ const uint8x16_t m0 = vdupq_n_u8(1);
+ const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+ const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+ const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+ const int8_t m32 = 32;
+
+ ggml_int8x16x4_t q3bytes;
+
+ float sum = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+ ggml_uint8x16x4_t q3h;
+
+ int32_t isum = 0;
+
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
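+        // utmp now holds the 16 sub-block scales as signed bytes in [-32, 31]: the low
+        // 4 bits come from the nibbles of the first 8 scale bytes, the top 2 bits from
+        // the last 4 bytes, with the offset of 32 removed above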
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+ const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+ const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+ q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+ q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+ q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+
+ scale += 4;
+
+ q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+ q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+ q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+ q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+ q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+ q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+ q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+ q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+
+ scale += 4;
+
+ if (j == 0) {
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+ }
+
+ }
+ sum += d * isum;
+
+ }
+
+ *s = sum;
+
+#elif defined __AVX2__
+
+ const __m256i m3 = _mm256_set1_epi8(3);
+ const __m256i mone = _mm256_set1_epi8(1);
+ const __m128i m32 = _mm_set1_epi8(32);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ uint32_t aux[3];
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ __m128i scales128 = _mm_set_epi32(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = _mm_sub_epi8(scales128, m32);
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+ // high bit
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+ // integer accumulator
+ __m256i sumi = _mm256_setzero_si256();
+
+ int bit = 0;
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits
+ const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+ // prepare low and high bits
+ const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+ const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
+
+ const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+ const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
+
+ const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+ const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
+
+ const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+ const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+ ++bit;
+
+ // load Q8 quants
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2-low-bit part and the 1-high-bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part already carries the offset of 4 (it is 4 if the high bit was not set and
+            // zero if it was set), matching the (high bit ? 0 : 4) subtraction of the scalar version.
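+            // e.g. for low bits 01 the value is 1 - 0 = 1 when the high bit is set, and 1 - 4 = -3 when it is not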
+ __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+ __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+ __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+ __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+ __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+ __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+ __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+ p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+ p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+ p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+ // accumulate
+ p16_0 = _mm256_add_epi32(p16_0, p16_1);
+ p16_2 = _mm256_add_epi32(p16_2, p16_3);
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+ }
+
+ // multiply with block scale and accumulate
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+ const __m128i m3 = _mm_set1_epi8(3);
+ const __m128i mone = _mm_set1_epi8(1);
+ const __m128i m32 = _mm_set1_epi8(32);
+ const __m128i m2 = _mm_set1_epi8(2);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ const uint32_t *aux;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ // Set up scales
+ aux = (const uint32_t *)x[i].scales;
+ __m128i scales128 = _mm_set_epi32(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = _mm_sub_epi8(scales128, m32);
+ const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+ const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+ const __m128i scales[2] = { scales_0, scales_1 };
+
+ // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+
+ // integer accumulator
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+ const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+ const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+
+ // prepare low and high bits
+ const int bit = j << 2;
+
+ const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+ const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+ const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+ const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
+
+ const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+ const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+ const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+ const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+
+ const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+ const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+ const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+ const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+
+ const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+ const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+ const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+ const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+
+ // load Q8 quants from block_q8_K.qs[QK_K]
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // Dot product: we multiply the 2-low-bit part and the 1-high-bit part separately, so we can use _mm_maddubs_epi16,
+            // and then subtract. The high bit part already carries the offset of 4 (it is 4 if the high bit was not set and
+            // zero if it was set), matching the (high bit ? 0 : 4) subtraction of the scalar version.
+ __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+ __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+ __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+ __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+ __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+ __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+ __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+ __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+ __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+ __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+ __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+ __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+ __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+ __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+ __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+ // multiply with scales
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+
+ // accumulate
+ p16_0 = _mm_add_epi32(p16_0, p16_1);
+ p16_2 = _mm_add_epi32(p16_2, p16_3);
+ p16_4 = _mm_add_epi32(p16_4, p16_5);
+ p16_6 = _mm_add_epi32(p16_6, p16_7);
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
+
+ }
+
+ // multiply with block scale and accumulate
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+ uint32_t aux[3];
+ uint32_t utmp[4];
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+ size_t vl = 32;
+ uint8_t m = 1;
+
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+ int sum_t = 0;
+
+ for (int j = 0; j < QK_K; j += 128) {
+
+ vl = 32;
+
+ // load Q3
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+
+ // compute mask for subtraction
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+ m <<= 1;
+
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
+ m <<= 1;
+
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
+ m <<= 1;
+
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+ m <<= 1;
+
+ // load Q8 and take product with Q3
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+ vl = 16;
+
+            // retrieve each 16-element half of the widened products and multiply it by its sub-block scale
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+ q3 += 32; q8 += 128; scale += 8;
+
+ }
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ sumf += d*sum_t;
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0x3);
+ const vector signed char lowMask1 = vec_splats((int8_t)0xf);
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector signed char v1 = vec_splats((signed char)0x1);
+ const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+ const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+ const vector signed char off = vec_splats((signed char)0x20);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ UNUSED(kmask1);
+ UNUSED(kmask2);
+
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+ vector signed char u1 = vec_and(u0, lowMask1);
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+ vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));
+ vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4);
+ vector signed char u31 = vec_and(u3, lowMask2);
+
+ u1 = vec_or(u1, u30);
+ u2 = vec_or(vec_sr(u0, v4), u31);
+
+ vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
+
+ vscales = vec_sub(vscales, off);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+ vector signed int vsumi4 = v0;
+ vector signed int vsumi5 = v0;
+ vector signed int vsumi6 = v0;
+ vector signed int vsumi7 = v0;
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __builtin_prefetch(q3, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
+ q3 += 32;
+
+            // the low 2 bits
+ vector signed char qxs00 = vec_and(qxs0, lowMask);
+ vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
+ vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
+ vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
+ vector signed char qxs10 = vec_and(qxs1, lowMask);
+ vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
+ vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
+ vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
+
+            // the 3rd bit
+ vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
+ vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
+ vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
+ vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
+ vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
+ vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
+ vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
+ vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
+ qxhs0 = vec_sr(qxhs0, v4);
+ qxhs1 = vec_sr(qxhs1, v4);
+
+ vector signed char q3x00 = vec_sub(qxs00, qxh00);
+ vector signed char q3x01 = vec_sub(qxs01, qxh01);
+ vector signed char q3x02 = vec_sub(qxs02, qxh02);
+ vector signed char q3x03 = vec_sub(qxs03, qxh03);
+ vector signed char q3x10 = vec_sub(qxs10, qxh10);
+ vector signed char q3x11 = vec_sub(qxs11, qxh11);
+ vector signed char q3x12 = vec_sub(qxs12, qxh12);
+ vector signed char q3x13 = vec_sub(qxs13, qxh13);
+
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y01 = vec_xl( 32, q8);
+ vector signed char q8y11 = vec_xl( 48, q8);
+ vector signed char q8y02 = vec_xl( 64, q8);
+ vector signed char q8y12 = vec_xl( 80, q8);
+ vector signed char q8y03 = vec_xl( 96, q8);
+ vector signed char q8y13 = vec_xl(112, q8);
+ q8 += 128;
+
+ vector signed short vscales_h = vec_unpackh(vscales);
+ vector signed short vs0 = vec_splat(vscales_h, 0);
+ vector signed short vs1 = vec_splat(vscales_h, 1);
+ vector signed short vs2 = vec_splat(vscales_h, 2);
+ vector signed short vs3 = vec_splat(vscales_h, 3);
+ vector signed short vs4 = vec_splat(vscales_h, 4);
+ vector signed short vs5 = vec_splat(vscales_h, 5);
+ vector signed short vs6 = vec_splat(vscales_h, 6);
+ vector signed short vs7 = vec_splat(vscales_h, 7);
+ vscales = vec_sld(vscales, vscales, 8);
+
+ vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
+ vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
+ vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
+ vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
+ vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
+ vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
+ vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
+ vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
+
+ vsumi0 = vec_msum(qv00, vs0, vsumi0);
+ vsumi1 = vec_msum(qv01, vs2, vsumi1);
+ vsumi2 = vec_msum(qv02, vs4, vsumi2);
+ vsumi3 = vec_msum(qv03, vs6, vsumi3);
+ vsumi4 = vec_msum(qv10, vs1, vsumi4);
+ vsumi5 = vec_msum(qv11, vs3, vsumi5);
+ vsumi6 = vec_msum(qv12, vs5, vsumi6);
+ vsumi7 = vec_msum(qv13, vs7, vsumi7);
+ }
+
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+ const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
+ const __m128i m32 = __lsx_vreplgr2vr_b(32);
+
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ uint32_t aux[3];
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const uint8_t * restrict q3 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ // Set up scales
+ memcpy(aux, x[i].scales, 12);
+ __m128i scales128 = lsx_set_w(
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+ scales128 = __lsx_vsub_b(scales128, m32);
+ const __m256i all_scales = lasx_ext8_16(scales128);
+ const __m128i l_scales = lasx_extracti128(all_scales, 0);
+ const __m128i h_scales = lasx_extracti128(all_scales, 1);
+ const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+ // high bit
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
+
+ // integer accumulator
+ __m256i sumi = __lasx_xvldi(0);
+
+ int bit = 0;
+ int is = 0;
+ __m256i xvbit;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ // load low 2 bits
+ const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
+
+ xvbit = __lasx_xvreplgr2vr_h(bit);
+ // prepare low and high bits
+ const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
+ const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+ ++bit;
+
+ xvbit = __lasx_xvreplgr2vr_h(bit);
+ const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
+ const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+ ++bit;
+
+ xvbit = __lasx_xvreplgr2vr_h(bit);
+ const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
+ const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+ ++bit;
+
+ xvbit = __lasx_xvreplgr2vr_h(bit);
+ const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
+ const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+ ++bit;
+
+ // load Q8 quants
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            // Dot product: we multiply the 2-low-bit part and the 1-high-bit part separately, so we can use lasx_maddubs_h,
+            // and then subtract. The high bit part already carries the offset of 4 (it is 4 if the high bit was not set and
+            // zero if it was set), matching the (high bit ? 0 : 4) subtraction of the scalar version.
+ __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
+ __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
+ __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
+ __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
+
+ __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
+ __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
+ __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
+
+ p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+ p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+ p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+ p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+ // multiply with scales
+ p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+ p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+ p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+ p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+ // accumulate
+ p16_0 = __lasx_xvadd_w(p16_0, p16_1);
+ p16_2 = __lasx_xvadd_w(p16_2, p16_3);
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
+ }
+ // multiply with block scale and accumulate
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); // FIXME
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
+ // scalar version
+ // This function is written like this so the compiler can manage to vectorize most of it
+    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so of the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+ // The ideal situation would be if we could just write the code once, and the compiler would
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
+ // write vectorized versions for AVX, ARM_NEON, etc.
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ memset(sums, 0, 8*sizeof(float));
+
+ uint32_t auxs[4];
+ const int8_t * scales = (const int8_t*)auxs;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict hm = x[i].hmask;
+ const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
+ int8_t * restrict a = aux8;
+ uint8_t m = 1;
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+ a += 32; m <<= 1;
+ q3 += 32;
+ }
+ a = aux8;
+
+ memcpy(auxs, x[i].scales, 12);
+ uint32_t tmp = auxs[2];
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+ for (int j = 0; j < QK_K/16; ++j) {
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+ q8 += 8; a += 8;
+ }
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+
+#endif
+
+}
+
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q4_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+ static const uint32_t kmask1 = 0x3f3f3f3f;
+ static const uint32_t kmask2 = 0x0f0f0f0f;
+ static const uint32_t kmask3 = 0x03030303;
+
+ uint32_t utmp[4];
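+    // kmask1/kmask2/kmask3 unpack the 12 bytes of x[i].scales into utmp so that the low
+    // 8 bytes hold the eight 6-bit sub-block scales and the high 8 bytes the eight 6-bit
+    // mins; every branch then computes
+    //   sumf += d * sum_j scale[j] * (q4 . q8)_j - dmin * sum_j min[j] * (bsums[2j] + bsums[2j+1])
+    // with d = y[i].d * fp16(x[i].d) and dmin = y[i].d * fp16(x[i].dmin).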
+
+#ifdef __ARM_NEON
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const int32x4_t mzero = vdupq_n_s32(0);
+
+ ggml_int8x16x2_t q4bytes;
+ ggml_int8x16x2_t q8bytes;
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+ memcpy(utmp, x[i].scales, 12);
+
+ uint32x2_t mins8 = { 0 };
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[0] &= kmask1;
+
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+ sumf -= dmin * vaddvq_s32(prod);
+
+ const uint8_t * scales = (const uint8_t *)utmp;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ int32_t sumi1 = 0;
+ int32_t sumi2 = 0;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
+
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+
+ q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+ }
+
+ sumf += d * (sumi1 + sumi2);
+
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+
+ __m256 acc = _mm256_setzero_ps();
+ __m128 acc_m = _mm_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+ __m256i sumi = _mm256_setzero_si256();
+
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+ const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4l = _mm256_and_si256(q4bits, m4);
+ const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+ const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+ p16l = _mm256_madd_epi16(scale_l, p16l);
+
+ const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+ p16h = _mm256_madd_epi16(scale_h, p16h);
+ const __m256i sumj = _mm256_add_epi32(p16l, p16h);
+
+ sumi = _mm256_add_epi32(sumi, sumj);
+ }
+
+ __m256 vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+ }
+
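+    // horizontally reduce the 4 lanes of acc_m (the mins contribution) to a single float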
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __AVX__
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i m2 = _mm_set1_epi8(0x2);
+
+ __m256 acc = _mm256_setzero_ps();
+ __m128 acc_m = _mm_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
+ acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+
+ __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+ const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+ q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+ const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+
+ const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+ p16l = _mm_madd_epi16(scale_l, p16l);
+ sumi_0 = _mm_add_epi32(sumi_0, p16l);
+ const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+ p16l = _mm_madd_epi16(scale_l, p16l);
+ sumi_1 = _mm_add_epi32(sumi_1, p16l);
+
+ const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+ p16h = _mm_madd_epi16(scale_h, p16h);
+ sumi_0 = _mm_add_epi32(sumi_0, p16h);
+ const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+ p16h = _mm_madd_epi16(scale_h, p16h);
+ sumi_1 = _mm_add_epi32(sumi_1, p16h);
+
+ }
+
+ __m256 vd = _mm256_set1_ps(d);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+ }
+
+ acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+ acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+ *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __riscv_v_intrinsic
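+    // RISC-V vector path: the mins correction uses strided loads of the q8 bsums
+    // (pairs summed), and each group of 32 nibbles is widened, multiplied with q8
+    // and reduce-summed before being weighted by its per-block scale.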
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ size_t vl = 8;
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ vl = 32;
+
+ int32_t sum_1 = 0;
+ int32_t sum_2 = 0;
+
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ // load Q4
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+ // load Q8 and multiply it with lower Q4 nibble
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+ // load Q8 and multiply it with upper Q4 nibble
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+ q4 += 32; q8 += 64;
+
+ }
+
+ sumf += d*(sum_1 + sum_2);
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
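+    // POWER9 path: the packed 6-bit scales/mins are rearranged with vector merges
+    // instead of the utmp/kmask trick, the mins correction uses vec_mule/vec_mulo on
+    // the q8 bsums, and the inner loop folds 128 quants per iteration with vec_msum.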
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v2 = vec_splats((uint8_t)2);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+ vector float vdmin = vec_mul(vxmin, vyd);
+
+ vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+ vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+ UNUSED(kmask1);
+ UNUSED(kmask2);
+ UNUSED(kmask3);
+ UNUSED(utmp);
+
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+ vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+ vector signed char u3 = vec_sr(u2, v4);
+
+ vector signed char u30 = u1;
+ vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+ u1 = vec_and(u0, lowMask1);
+ u2 = vec_or(u30, u31);
+
+ vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+ vector signed short vscales = vec_unpackh(utmps);
+ vector signed short q4xmins = vec_unpackl(utmps);
+ vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
+ vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
+
+ vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
+ vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
+ vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
+ vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
+
+ vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+ vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+ vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+ vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/64; j+=2) {
+ __builtin_prefetch(q4, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+ vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
+ vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
+ q4 += 64;
+
+ vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+ vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);
+ vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+ vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);
+ vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);
+ vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);
+ vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);
+ vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);
+
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y01 = vec_xl( 32, q8);
+ vector signed char q8y11 = vec_xl( 48, q8);
+ vector signed char q8y20 = vec_xl( 64, q8);
+ vector signed char q8y30 = vec_xl( 80, q8);
+ vector signed char q8y21 = vec_xl( 96, q8);
+ vector signed char q8y31 = vec_xl(112, q8);
+ q8 += 128;
+
+ vector signed int qv00 = vec_msum(q8y00, q4x00, v0);
+ vector signed int qv01 = vec_msum(q8y01, q4x01, v0);
+ vector signed int qv10 = vec_msum(q8y10, q4x10, v0);
+ vector signed int qv11 = vec_msum(q8y11, q4x11, v0);
+ vector signed int qv20 = vec_msum(q8y20, q4x20, v0);
+ vector signed int qv21 = vec_msum(q8y21, q4x21, v0);
+ vector signed int qv30 = vec_msum(q8y30, q4x30, v0);
+ vector signed int qv31 = vec_msum(q8y31, q4x31, v0);
+
+ vector signed int vscales_h = vec_unpackh(vscales);
+ vector signed int vs0 = vec_splat(vscales_h, 0);
+ vector signed int vs1 = vec_splat(vscales_h, 1);
+ vector signed int vs2 = vec_splat(vscales_h, 2);
+ vector signed int vs3 = vec_splat(vscales_h, 3);
+ vscales = vec_sld(vscales, vscales, 8);
+
+ vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);
+ vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);
+ vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3);
+
+ vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);
+ vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);
+ vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
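+    // LoongArch ASX path: mirrors the AVX2 branch with 256-bit LASX intrinsics;
+    // lasx_maddubs_h and lasx_madd_h take the roles of maddubs/madd_epi16.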
+ GGML_UNUSED(kmask1);
+ GGML_UNUSED(kmask2);
+ GGML_UNUSED(kmask3);
+
+ const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+
+ __m256 acc = (__m256)__lasx_xvldi(0);
+ __m128 acc_m = (__m128)__lsx_vldi(0);
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+ const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
+
+ const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
+ const __m256i scales = lasx_insertf128(sc128, sc128);
+
+ __m256i sumi = __lasx_xvldi(0);
+
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+ const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4l = __lasx_xvand_v(q4bits, m4);
+ const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+
+ const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ __m256i p16l = lasx_maddubs_h(q4l, q8l);
+ p16l = lasx_madd_h(scale_l, p16l);
+
+ const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ __m256i p16h = lasx_maddubs_h(q4h, q8h);
+ p16h = lasx_madd_h(scale_h, p16h);
+ const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
+
+ sumi = __lasx_xvadd_w(sumi, sumj);
+ }
+
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+ }
+
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
+ __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
+
+    ft_union fi;
+    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
+    *s = hsum_float_8(acc) + fi.f;
+#else
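+    // Scalar reference path: expand all 256 nibbles into aux8, accumulate the
+    // scale-weighted q4*q8 products into aux32, then apply d and subtract the
+    // dmin * (mins . bsums) correction.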
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
+ int8_t * restrict a = aux8;
+ for (int j = 0; j < QK_K/64; ++j) {
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+ a += 32;
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+ a += 32; q4 += 32;
+ }
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ int sumi = 0;
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+ a = aux8;
+ int is = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ int32_t scale = scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ }
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+ sumf -= dmin * sumi;
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+
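+// Dot product of a row of Q5_K quants with Q8_K activations. Q5_K stores the low
+// 4 bits of each quant in qs and the fifth bit in qh; the rest of the super-block
+// layout (fp16 d/dmin plus 12 bytes of packed 6-bit scales and mins) matches Q4_K.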
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+ static const uint32_t kmask1 = 0x3f3f3f3f;
+ static const uint32_t kmask2 = 0x0f0f0f0f;
+ static const uint32_t kmask3 = 0x03030303;
+
+ uint32_t utmp[4];
+
+#ifdef __ARM_NEON
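+    // NEON path: the high bits are extracted from qh two sub-blocks at a time,
+    // shifted into bit 4 and OR-ed onto the low nibbles; block sums come from
+    // ggml_vdotq_s32 and the mins correction (sumi_mins) is subtracted once per
+    // super-block.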
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const uint8x16_t mone = vdupq_n_u8(1);
+ const uint8x16_t mtwo = vdupq_n_u8(2);
+ const int32x4_t mzero = vdupq_n_s32(0);
+
+ ggml_int8x16x4_t q5bytes;
+
+ float sumf = 0;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+ int32_t sumi_mins = vaddvq_s32(prod);
+
+ const uint8_t * scales = (const uint8_t *)utmp;
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+ ggml_uint8x16x4_t q5h;
+
+ int32_t sumi = 0;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+ const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+ q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+ q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+ q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+ qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+ qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+
+ q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+ q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+ q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+ q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
+
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+ sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+ }
+
+ sumf += d * sumi - dmin * sumi_mins;
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
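+    // AVX2 path: same structure as the Q4_K kernel, plus hbits/hmask bookkeeping
+    // that pulls out one high bit per sub-block and adds it (shifted to bit 4)
+    // onto the low nibbles.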
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+ const __m128i mzero = _mm_setzero_si128();
+ const __m256i mone = _mm256_set1_epi8(1);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0.f;
+
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+ const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+ const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+ summs += dmin * _mm_extract_epi32(hsum, 0);
+
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+ const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+ __m256i hmask = mone;
+
+ __m256i sumi = _mm256_setzero_si256();
+
+ int bit = 0;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+ const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
+ const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+ const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+ const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
+ hmask = _mm256_slli_epi16(hmask, 1);
+
+ const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+ const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+ const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
+ hmask = _mm256_slli_epi16(hmask, 1);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+ __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+ p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+ }
+
+ __m256 vd = _mm256_set1_ps(d);
+ acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
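+    // AVX path: 128-bit variant of the AVX2 branch; the two 16-byte halves of each
+    // 32-byte group feed separate accumulators (sumi_0/sumi_1) that are combined at
+    // the end of each super-block.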
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i mzero = _mm_setzero_si128();
+ const __m128i mone = _mm_set1_epi8(1);
+ const __m128i m2 = _mm_set1_epi8(2);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0.f;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+ const __m128i scales = _mm_cvtepu8_epi16(utmps);
+ const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+ const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+ const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+ const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+ const __m128i prod = _mm_madd_epi16(mins, q8s);
+ const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+ summs += dmin * _mm_extract_epi32(hsum, 0);
+
+ const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+ const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+ __m128i hmask = mone;
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ int bit = 0;
+
+ __m128i shuffle = _mm_set1_epi16(0x0100);
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi16(shuffle, m2);
+
+ const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+ const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
+ __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+ __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+ __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+ __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+ __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+ __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+ hmask = _mm_slli_epi16(hmask, 1);
+
+ __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+ p16_0 = _mm_madd_epi16(scale_0, p16_0);
+ p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+ q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+ q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+ q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+ q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+ q5_0 = _mm_add_epi8(q5l_0, q5h_0);
+ q5_1 = _mm_add_epi8(q5l_1, q5h_1);
+ hmask = _mm_slli_epi16(hmask, 1);
+
+ q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+ __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+ p16_2 = _mm_madd_epi16(scale_1, p16_2);
+ p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+ }
+
+ __m256 vd = _mm256_set1_ps(d);
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
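+    // RISC-V vector path: the high bit is applied with a mask register (vmsne plus a
+    // masked add of 16) instead of an OR, and the widened products are reduce-summed
+    // per block scale.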
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ float sumf = 0;
+ float sums = 0.0;
+
+ size_t vl;
+
+ for (int i = 0; i < nb; ++i) {
+
+ vl = 8;
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const uint8_t * restrict hm = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+ vl = 32;
+ int32_t aux32 = 0;
+ int is = 0;
+
+ uint8_t m = 1;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ // load Q5 and Q8
+ vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+ vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+ // compute mask for addition
+ vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+ vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+ m <<= 1;
+
+ vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+ vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+ vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+ m <<= 1;
+
+ vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+ vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+ vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+ vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+ q5 += 32; q8 += 64;
+
+ }
+
+ vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+ sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+ }
+
+ *s = sumf+sums;
+
+#elif defined(__POWER9_VECTOR__)
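+    // POWER9 path: as in the Q4_K kernel, with the qh bits masked and shifted into
+    // bit 4 via vec_and/vec_sl before vec_msum accumulates 64 quants per iteration.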
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v1 = vec_splats((unsigned char)0x1);
+ const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+ const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+ vector float vdmin = vec_mul(vxmin, vyd);
+
+ UNUSED(kmask1);
+ UNUSED(kmask2);
+ UNUSED(kmask3);
+ UNUSED(utmp);
+
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+ vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+ vector signed char u3 = vec_sr(u2, v4);
+
+ vector signed char u30 = u1;
+ vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+ u1 = vec_and(u0, lowMask1);
+ u2 = vec_or(u30, u31);
+
+ vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+ vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+ vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+ vector signed short vscales = vec_unpackh(utmps);
+
+ vector signed short q5xmins = vec_unpackl(utmps);
+ vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
+ vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+
+ vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
+ vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
+ vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
+ vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+
+ vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+ vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+ vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+ vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ __builtin_prefetch(q5, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
+ q5 += 32;
+
+ vector signed char qxs00 = vec_and(qxs0, lowMask);
+ vector signed char qxs01 = vec_sr(qxs0, v4);
+ vector signed char qxs10 = vec_and(qxs1, lowMask);
+ vector signed char qxs11 = vec_sr(qxs1, v4);
+
+ vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
+ vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
+ vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
+ vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
+ qxhs0 = vec_sr(qxhs0, v2);
+ qxhs1 = vec_sr(qxhs1, v2);
+
+ vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);
+ vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);
+ vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);
+ vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);
+
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl(16, q8);
+ vector signed char q8y01 = vec_xl(32, q8);
+ vector signed char q8y11 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed int qv00 = vec_msum(q8y00, q5x00, v0);
+ vector signed int qv01 = vec_msum(q8y01, q5x01, v0);
+ vector signed int qv10 = vec_msum(q8y10, q5x10, v0);
+ vector signed int qv11 = vec_msum(q8y11, q5x11, v0);
+
+ vector signed int vscales_h = vec_unpackh(vscales);
+ vector signed int vs0 = vec_splat(vscales_h, 0);
+ vector signed int vs1 = vec_splat(vscales_h, 1);
+ vscales = vec_sld(vscales, vscales, 12);
+
+ vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+ vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);
+ vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);
+ vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
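+    // LoongArch ASX path: LASX mirror of the AVX2 branch; the high-bit shift amount
+    // is broadcast into xvbit because the lane-wise shift used here takes a vector
+    // count rather than an immediate.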
+ GGML_UNUSED(kmask1);
+ GGML_UNUSED(kmask2);
+ GGML_UNUSED(kmask3);
+
+ const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+ const __m128i mzero = __lsx_vldi(0);
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
+
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ float summs = 0.f;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const uint8_t * restrict q5 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+ const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+ const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
+ summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
+
+ const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
+ const __m256i scales = lasx_insertf128(sc128, sc128);
+
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
+ __m256i hmask = mone;
+
+ __m256i sumi = __lasx_xvldi(0);
+
+ int bit = 0;
+ __m256i xvbit;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+
+ const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+ const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+ const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
+
+ xvbit = __lasx_xvreplgr2vr_h(bit++);
+ const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
+ const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+ const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
+ hmask = __lasx_xvslli_h(hmask, 1);
+
+ xvbit = __lasx_xvreplgr2vr_h(bit++);
+ const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
+ const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+ const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
+ hmask = __lasx_xvslli_h(hmask, 1);
+
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+ __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+
+ p16_0 = lasx_madd_h(scale_0, p16_0);
+ p16_1 = lasx_madd_h(scale_1, p16_1);
+
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
+ }
+
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+ }
+
+ *s = hsum_float_8(acc) + summs;
+
+#else
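+    // Scalar reference path: identical to the Q4_K fallback except that 16 is added
+    // to every expanded value whose qh bit is set.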
+
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].qs;
+ const uint8_t * restrict hm = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
+ int8_t * restrict a = aux8;
+ uint8_t m = 1;
+ for (int j = 0; j < QK_K/64; ++j) {
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ a += 32; m <<= 1;
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+ a += 32; m <<= 1;
+ q4 += 32;
+ }
+ memcpy(utmp, x[i].scales, 12);
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux = utmp[1] & kmask1;
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+ utmp[2] = uaux;
+ utmp[0] &= kmask1;
+
+ int sumi = 0;
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+ a = aux8;
+ int is = 0;
+ for (int j = 0; j < QK_K/32; ++j) {
+ int32_t scale = scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ }
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+ sumf -= dmin * sumi;
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+
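+// Dot product of a row of Q6_K quants with Q8_K activations. Q6_K packs the low
+// 4 bits in ql and the high 2 bits in qh, uses sixteen signed 8-bit scales and a
+// single fp16 d (no dmin); the stored 6-bit value represents q - 32.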
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q6_K * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
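+    // NEON path: rather than subtracting 32 from every value, the constant offset is
+    // handled once per super-block by subtracting 32 * sum(scale * bsum) at the end
+    // (the commented-out m32s variant below is kept for reference).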
+ float sum = 0;
+
+ const uint8x16_t m4b = vdupq_n_u8(0xF);
+ const int32x4_t vzero = vdupq_n_s32(0);
+ //const int8x16_t m32s = vdupq_n_s8(32);
+
+ const uint8x16_t mone = vdupq_n_u8(3);
+
+ ggml_int8x16x4_t q6bytes;
+ ggml_uint8x16x4_t q6h;
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+ const int8x16_t scales = vld1q_s8(scale);
+ const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
+
+ const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+ vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+ vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+ int32_t isum_mins = vaddvq_s32(prod);
+
+ int32_t isum = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+ ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+ ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+ uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 2);
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+
+ scale += 4;
+
+ q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ shifted = vshrq_n_u8(qhbits.val[0], 4);
+ q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 4);
+ q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[0], 6);
+ q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+ shifted = vshrq_n_u8(qhbits.val[1], 6);
+ q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+ //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
+ //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
+ //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
+ //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
+ q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+ q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+ q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+ q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+
+ isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+ vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+ scale += 4;
+ }
+ //sum += isum * d_all * y[i].d;
+ sum += d_all * y[i].d * (isum - 32 * isum_mins);
+
+ }
+ *s = sum;
+
+#elif defined __AVX2__
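+    // AVX2 path: the 32 offset is removed per product pair by computing
+    // maddubs(32, q8) and subtracting it from maddubs(q6, q8) before the 16-bit
+    // scale multiply.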
+
+ const __m256i m4 = _mm256_set1_epi8(0xF);
+ const __m256i m2 = _mm256_set1_epi8(3);
+ const __m256i m32s = _mm256_set1_epi8(32);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+ __m256i sumi = _mm256_setzero_si256();
+
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+ is += 4;
+
+ const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+ const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+
+ const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+ const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+ const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+ const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+
+ const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+ const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+ const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+ const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+
+ const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+ __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+ __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+ __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+ __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
+
+ __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+ __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+ __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+ __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
+
+ p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+ p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+ p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+ p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
+
+ }
+
+ acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __AVX__
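+    // AVX path: 128-bit variant of the AVX2 branch; each 128-quant block is handled
+    // as eight 16-byte slices with a rotating scale shuffle.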
+
+ const __m128i m4 = _mm_set1_epi8(0xF);
+ const __m128i m3 = _mm_set1_epi8(3);
+ const __m128i m32s = _mm_set1_epi8(32);
+ const __m128i m2 = _mm_set1_epi8(2);
+
+ __m256 acc = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+ __m128i sumi_0 = _mm_setzero_si128();
+ __m128i sumi_1 = _mm_setzero_si128();
+
+ __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+ const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+ const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+ const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+ const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4);
+ const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4);
+ const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4);
+ const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4);
+ const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4);
+ const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4);
+
+ const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+ const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+ const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0);
+ const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1);
+ const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2);
+ const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3);
+ const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4);
+ const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5);
+ const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6);
+ const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7);
+
+ const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+ const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+ __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
+ __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
+ __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
+ __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
+ __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
+ __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
+ __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
+ __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
+
+ __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+ __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+ __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+ __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+ __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+ __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+ __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+ __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+ p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+ p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+ p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+ p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+ p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+ p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+ p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+ p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+ const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+ const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
+ shuffle = _mm_add_epi8(shuffle, m2);
+
+ p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+ p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+ p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+ p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+ p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+ p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5);
+ p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+ p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7);
+
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+ sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+ sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+ }
+
+ __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+ acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
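+    // RISC-V vector path: reassemble the 6-bit values, subtract the 32 offset
+    // directly, widen-multiply with q8, weight each 16-element half with its scale
+    // and chain the reductions so only one scalar move is needed per 128 quants.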
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const int8_t * restrict scale = x[i].scales;
+
+ size_t vl;
+
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+ int sum_t = 0;
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ vl = 32;
+
+ // load qh
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+ // load Q6
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+ // load Q8 and take product
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+ vl = 16;
+
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+ q6 += 64; qh += 32; q8 += 128; is=8;
+
+ }
+
+ sumf += d * sum_t;
+
+ }
+
+ *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
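+    // POWER9 path: q6*q8 byte products are pair-summed into 16-bit lanes with
+    // vec_mule/vec_mulo, then vec_msum folds them against the splatted 16-bit scales
+    // into eight 32-bit accumulators.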
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+ const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+ const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+ const vector signed char off = vec_splats((signed char)0x20);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+ vector signed int vsumi4 = v0;
+ vector signed int vsumi5 = v0;
+ vector signed int vsumi6 = v0;
+ vector signed int vsumi7 = v0;
+
+ const uint8_t * restrict q6 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict qs = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __builtin_prefetch(q6, 0, 0);
+ __builtin_prefetch(qh, 0, 0);
+ __builtin_prefetch(q8, 0, 0);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
+ vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
+ vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
+ q6 += 64;
+
+ vector signed char qxs00 = vec_and(qxs0, lowMask);
+ vector signed char qxs01 = vec_sr(qxs0, v4);
+ vector signed char qxs10 = vec_and(qxs1, lowMask);
+ vector signed char qxs11 = vec_sr(qxs1, v4);
+ vector signed char qxs20 = vec_and(qxs2, lowMask);
+ vector signed char qxs21 = vec_sr(qxs2, v4);
+ vector signed char qxs30 = vec_and(qxs3, lowMask);
+ vector signed char qxs31 = vec_sr(qxs3, v4);
+
+ vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
+ vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
+ qh += 32;
+
+ vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
+ vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
+ vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
+ vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
+ vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
+ vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
+ vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
+ vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
+
+ vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
+ vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
+ vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
+ vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
+ vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
+ vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
+ vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
+ vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
+
+ vector signed char q8y00 = vec_xl( 0, q8);
+ vector signed char q8y10 = vec_xl( 16, q8);
+ vector signed char q8y20 = vec_xl( 32, q8);
+ vector signed char q8y30 = vec_xl( 48, q8);
+ vector signed char q8y01 = vec_xl( 64, q8);
+ vector signed char q8y11 = vec_xl( 80, q8);
+ vector signed char q8y21 = vec_xl( 96, q8);
+ vector signed char q8y31 = vec_xl(112, q8);
+ q8 += 128;
+
+ vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
+ vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
+ vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
+ vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
+ vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
+ vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
+ vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
+ vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
+
+ vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
+ qs += 8;
+
+ vector signed short vs0 = vec_splat(vscales, 0);
+ vector signed short vs1 = vec_splat(vscales, 1);
+ vector signed short vs2 = vec_splat(vscales, 2);
+ vector signed short vs3 = vec_splat(vscales, 3);
+ vector signed short vs4 = vec_splat(vscales, 4);
+ vector signed short vs5 = vec_splat(vscales, 5);
+ vector signed short vs6 = vec_splat(vscales, 6);
+ vector signed short vs7 = vec_splat(vscales, 7);
+
+ vsumi0 = vec_msum(qv00, vs0, vsumi0);
+ vsumi1 = vec_msum(qv01, vs4, vsumi1);
+ vsumi2 = vec_msum(qv10, vs1, vsumi2);
+ vsumi3 = vec_msum(qv11, vs5, vsumi3);
+ vsumi4 = vec_msum(qv20, vs2, vsumi4);
+ vsumi5 = vec_msum(qv21, vs6, vsumi5);
+ vsumi6 = vec_msum(qv30, vs3, vsumi6);
+ vsumi7 = vec_msum(qv31, vs7, vsumi7);
+ }
+
+ vsumi0 = vec_add(vsumi0, vsumi4);
+ vsumi1 = vec_add(vsumi1, vsumi5);
+ vsumi2 = vec_add(vsumi2, vsumi6);
+ vsumi3 = vec_add(vsumi3, vsumi7);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+ const __m256i m2 = __lasx_xvreplgr2vr_b(3);
+ const __m256i m32s = __lasx_xvreplgr2vr_b(32);
+
+ __m256 acc = (__m256)__lasx_xvldi(0);
+
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+
+ const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+
+ __m256i sumi = __lasx_xvldi(0);
+
+ int is = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
+ const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
+ const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
+ const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
+ is += 4;
+
+ const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+ const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
+
+ const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
+ const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
+ const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
+ const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+
+ const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
+ const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
+ const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
+ const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+ __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
+ __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
+ __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
+ __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
+
+ __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
+ __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
+ __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
+ __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
+
+ p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+ p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+ p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+ p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+ p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
+ p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
+ p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
+ p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
+ }
+
+ acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+ }
+
+ *s = hsum_float_8(acc);
+
+#else
+
+ int8_t aux8[QK_K];
+ int16_t aux16[8];
+ float sums [8];
+ int32_t aux32[8];
+ memset(sums, 0, 8*sizeof(float));
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * restrict q4 = x[i].ql;
+ const uint8_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+ memset(aux32, 0, 8*sizeof(int32_t));
+ int8_t * restrict a = aux8;
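+        // unpack the 6-bit quants: low 4 bits come from q4 (ql), the high 2 bits from qh,
+        // then re-center the 0..63 value to the signed range [-32, 31]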
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+ }
+ a += 128;
+ q4 += 64;
+ qh += 32;
+ }
+ a = aux8;
+ int is = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int scale = x[i].scales[is++];
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+ q8 += 8; a += 8;
+ }
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+ }
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
+ *s = sumf;
+#endif
+}
+
+#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
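+// 128 x 8 precomputed signs (+1/-1) shared by the iq2 dot products below: the low 7 bits of the
+// group index give the signs of the first 7 values and the 8th sign is chosen so that every
+// group of 8 contains an even number of -1 values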
+static const int8_t keven_signs_q2xs[1024] = {
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+#endif
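+
+// Illustrative sketch (compiled out, not used at runtime): assuming the parity rule described
+// above, the table could be regenerated with a hypothetical helper along these lines.
+#if 0
+static void gen_keven_signs_q2xs(int8_t * dst) {
+    for (int k = 0; k < 128; ++k) {
+        int nneg = 0;
+        for (int j = 0; j < 7; ++j) {
+            const int bit = (k >> j) & 1;   // bit j of the code sets the sign of value j
+            dst[8*k + j] = bit ? -1 : 1;
+            nneg += bit;
+        }
+        dst[8*k + 7] = (nneg & 1) ? -1 : 1; // 8th sign keeps the number of -1 values even
+    }
+}
+#endif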
+
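+// iq2_xxs: every 32 weights are described by four uint16_t values; interpreted as two uint32_t,
+// the first packs four 8-bit indices into iq2xxs_grid (8 weights each) and the second packs four
+// 7-bit sign codes plus a 4-bit block scale in its top nibble, applied as (2*scale + 1)/8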
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq2_xxs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ ggml_int8x16x4_t q2u;
+ ggml_int8x16x4_t q2s;
+ ggml_int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ float sumf1 = 0, sumf2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+ }
+ sumf += d*(sumf1 + sumf2);
+ }
+ *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[1] >> 28;
+ const uint16_t ls2 = aux32[3] >> 28;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+ const uint16_t ls1 = aux32[1] >> 28;
+ const uint16_t ls2 = aux32[3] >> 28;
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+ }
+
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+ const vector int v0 = vec_splats((int32_t)0);
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/32; j += 2) {
+ __builtin_prefetch(q2, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ memcpy(aux32, q2, 4*sizeof(uint32_t));
+ q2 += 8;
+
+ vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])};
+ vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])};
+ vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])};
+ vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])};
+
+ vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))};
+ vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))};
+ vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))};
+ vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))};
+
+ vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+ vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+ vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+ vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+ const uint16_t ls0 = aux32[1] >> 28;
+ const uint16_t ls1 = aux32[3] >> 28;
+
+ vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));
+ vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));
+
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[4];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+
+ const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+ const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+ const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
+ const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+ const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[1] >> 28;
+ const uint16_t ls2 = aux32[3] >> 28;
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+ uint32_t aux32[2];
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(aux32, q2, 2*sizeof(uint32_t));
+ q2 += 4;
+ const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+#endif
+}
+
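+// iq2_xs: each uint16_t element packs a 9-bit index into the 512-entry iq2xs_grid (8 weights)
+// in its low bits and a 7-bit sign code in its high bits; block scales are stored as 4-bit
+// pairs in x[i].scales and applied as (2*ls + 1)/8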
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq2_xs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ ggml_int8x16x4_t q2u;
+ ggml_int8x16x4_t q2s;
+ ggml_int8x16x4_t q8b;
+
+ int32x4x4_t scales32;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+ const uint8x8_t scales8 = vld1_u8(x[i].scales);
+ const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+ const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+ uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+ scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+ const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+ const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+ scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+ scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+ scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+ scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+ int32x4_t sumi = vdupq_n_s32(0);
+ for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+ q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
+ q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
+ q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
+ q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
+ q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
+ q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+ q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+ q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+ q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+ const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
+ const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
+ const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
+ const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
+ const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+ sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
+ q2 += 8;
+ }
+ sumf += d*vaddvq_s32(sumi);
+ }
+ *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+ const __m256i mone = _mm256_set1_epi8(1);
+ static const char block_sign_shuffle_mask_1[32] = {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ };
+ static const char block_sign_shuffle_mask_2[32] = {
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+ };
+ static const uint8_t bit_selector_mask_bytes[32] = {
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+ const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+ const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
+
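+    // only 7 sign bits per element are stored; the 8th is their parity, recovered below by
+    // folding the 7 bits into 4 with an xor and looking up the parity in k_bit_helper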
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+ const __m256i m511 = _mm256_set1_epi16(511);
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance: the 16-bit grid indices
+    // are written to a vector register and read back through the gindex alias below
+ __m256i aux_gindex;
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ __m128i stmp = _mm_set1_epi64x(aux64);
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+ const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
+ aux_gindex = _mm256_and_si256(q2_data, m511);
+
+ const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+ const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+ const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+ const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+ const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+ const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+ const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+ const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+ const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+ const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+ const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
+ const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
+
+ __m256i signs;
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+ signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+ signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
+ const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
+
+ const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
+ const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+ const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+ const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
+
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+ const __m128i mone = _mm_set1_epi8(1);
+ static const char block_sign_shuffle_mask_1[32] = {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ };
+ static const char block_sign_shuffle_mask_2[32] = {
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+ };
+ static const uint8_t bit_selector_mask_bytes[32] = {
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
+ const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
+ const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
+ const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
+ const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
+ const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
+
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
+ const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
+ const __m128i m511 = _mm_set1_epi16(511);
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance: the 16-bit grid indices
+    // are written to a vector register and read back through the gindex alias below
+ __m256i aux_gindex;
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ __m128i stmp = _mm_set1_epi64x(aux64);
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+ const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
+ const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
+ aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
+
+ const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
+ const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
+ const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
+ const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
+ const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
+ const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
+
+ const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
+ const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
+ const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
+ const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
+
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
+ const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
+ const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
+ const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+ const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
+
+            // what the AVX2 path calls full_signs_1 is full_sign_bits_0 here,
+            // and full_signs_2 is full_sign_bits_1
+ __m128i signs_0, signs_1;
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
+
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
+
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+ const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
+ const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
+
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+ const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
+ const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
+
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+ const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
+ const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
+ const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
+ const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
+
+ __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
+ const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
+ const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
+ const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
+ const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
+ const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
+ const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
+ const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
+ const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
+ }
+
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__loongarch_asx)
+
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
+ static const char block_sign_shuffle_mask_1[32] = {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ };
+ static const char block_sign_shuffle_mask_2[32] = {
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+ };
+ static const uint8_t bit_selector_mask_bytes[32] = {
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
+ const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
+ const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
+
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
+ const __m256i m511 = __lasx_xvreplgr2vr_h(511);
+ const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+ const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+ uint64_t aux64;
+
+    // somewhat hacky, but gives a significant boost in performance: the 16-bit grid indices
+    // are written to a vector register and read back through the gindex alias below
+ __m256i aux_gindex;
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ __m128i stmp = __lsx_vreplgr2vr_d(aux64);
+ stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
+ const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);
+
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+ const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16;
+ aux_gindex = __lasx_xvand_v(q2_data, m511);
+
+ const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
+ const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
+ const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);
+
+ const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
+ const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);
+
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+
+ const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+ const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+ const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+ const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+ const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
+ const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
+ const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
+ const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
+
+ __m256i signs;
+ signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
+
+ signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
+
+ signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);
+
+ signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3);
+ const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4);
+
+ const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
+ const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
+ const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
+ const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));
+
+ sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
+ sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
+ sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
+ sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+#elif defined(__POWER9_VECTOR__)
+ const vector int v0 = vec_splats((int32_t)0);
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint16_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict sc = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/64; ++j) {
+ __builtin_prefetch(q2, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))};
+ vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))};
+ vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))};
+ vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))};
+
+ vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))};
+ vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))};
+ vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))};
+ vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))};
+ q2 += 8;
+
+ vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+ vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+ vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+ vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+ const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+ const uint16_t ls1 = (uint16_t)(sc[0] >> 4);
+ const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+ const uint16_t ls3 = (uint16_t)(sc[1] >> 4);
+ sc += 2;
+
+ vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+ vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+ vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+ vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+ vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = 0.125f * vec_extract(vsumf0, 0);
+#else
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict sc = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
+ const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls1;
+ sumi = 0;
+ for (int l = 2; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
+ for (int j = 0; j < 8; ++j) {
+ sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += sumi * ls2;
+ q2 += 4;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.125f * sumf;
+#endif
+}
+
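+// iq2_s: the 10-bit grid index of each group of 8 weights is split between an 8-bit value in
+// x[i].qs and 2 extra high bits in x[i].qh; the sign bytes follow the indices at x[i].qs + QK_K/8
+// and the 4-bit block scales are applied as (2*ls + 1)/8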
+void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq2_s * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+ const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+ const uint8x16_t mask2 = vld1q_u8(k_mask2);
+ const uint8x16_t m1 = vdupq_n_u8(1);
+ const int32x4_t vzero = vdupq_n_s32(0);
+
+ uint8x16x2_t vs;
+ ggml_int8x16x4_t q2s;
+ ggml_int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const int8_t * restrict q8 = y[i].qs;
+
+ int sumi1 = 0, sumi2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+ q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
+ q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
+ q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
+ q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
+ vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
+ qs += 8;
+
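+ // Broadcast the 32 sign bits of this 32-quant block, select one bit per byte via mask1/mask2
+ // and compare: set bits become 0xFF, and OR-ing with 1 below turns that into a -1/+1 multiplier.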
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+ q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
+ q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
+
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[0] = vceqq_u8(vs.val[0], mask2);
+ vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+ signs += 4;
+
+ q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
+ q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
+
+ const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
+ const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
+ const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
+ const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
+
+ sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
+ sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4));
+ sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
+ sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4));
+ }
+ sumf += d*(sumi1 + sumi2);
+ }
+
+ *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+ uint64_t aux64;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
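+ // Unpack the 16 nibble scales (one per 16 quants) into bytes 2*s+1: low nibbles land in the
+ // low 8 bytes, high nibbles in the high 8, giving the sub-block order noted below after widening to int16.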
+ const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+ const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+ iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+ const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+ iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+ qs += 8;
+
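+ // Expand the packed sign bits to a byte mask (0xFF where set) and apply it with the xor/sub
+ // trick: (q8 ^ s) - s negates exactly the bytes whose sign bit is set.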
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+ const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+ const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+ aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+ const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+ const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m1 = _mm_set1_epi8(1);
+
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+ uint64_t aux64;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+ const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
+ const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
+
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
+ qs += 8;
+
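+ // Same sign expansion and xor/sub trick as the AVX2 path, split into two 128-bit halves.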
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+ __m128i aux128_1 = aux128_0;
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+ aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+ aux128_1 = aux128_0;
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+ signs += 4;
+
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+ }
+
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+ const vector int v0 = vec_splats((int32_t)0);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+ const vector unsigned char mask1 = vec_xl(16, k_mask1);
+ const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint8_t * restrict q2 = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const uint8_t * restrict sc = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+
+ for (int j = 0; j < QK_K/32; j += 2) {
+ __builtin_prefetch(q2, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
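+ // Gather the grid rows: each iq2s_grid entry is 8 bytes, indexed by one q2 byte plus 2 high
+ // bits from qh; two entries fill each 16-byte vector.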
+ vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))};
+ vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))};
+ vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))};
+ vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))};
+ q2 += 8;
+ qh += 2;
+
+ vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+ vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+ signs += 4;
+
+ vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+ vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+ vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0);
+ vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1);
+
+ vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+ vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+ vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+ vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+ vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0);
+ vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1);
+ vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2);
+ vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+ const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+ const uint16_t ls1 = (uint16_t)(sc[0] >> 4);
+ const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+ const uint16_t ls3 = (uint16_t)(sc[1] >> 4);
+ sc += 2;
+
+ vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+ vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+ vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+ vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+ vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+ const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+ const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+ const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+ uint64_t aux64;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const int8_t * restrict q8 = y[i].qs;
+
+ __m128i tmp1 = __lsx_vldi(0); // zero-initialize so the inserts below never read an indeterminate vector
+ memcpy(&aux64, x[i].scales, 8);
+ tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
+ tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
+ const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
+ const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+ iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+ const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+ iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+ qs += 8;
+
+ __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+ aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+ const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
+ const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
+ float sumf = 0;
+ for (int i = 0; i < nb; i++) {
+
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint8_t * signs = qs + QK_K/8;
+
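+ // Reference path: expand each 8(+2)-bit grid index to 8 quants, flip signs according to the
+ // per-byte sign bits, and weight each 16-quant half block by (2*scale_nibble + 1).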
+ int bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
+ int ls2 = 1 + 2*(x[i].scales[ib32] >> 4);
+ int sumi1 = 0, sumi2 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+ for (int j = 0; j < 8; ++j) {
+ sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ for (int l = 2; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+ for (int j = 0; j < 8; ++j) {
+ sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ bsum += ls1 * sumi1 + ls2 * sumi2;
+ qs += 4;
+ signs += 4;
+ }
+
+ sumf += d * bsum;
+ }
+
+ *s = 0.125f * sumf;
+
+#endif
+
+}
+
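+ // Dot product of an IQ3_XXS row with a Q8_K row. The first QK_K/4 bytes of a block are 8-bit
+ // indices into iq3xxs_grid (4 quants per entry); they are followed by one 32-bit word per
+ // 32-quant block packing 4 x 7-bit sign-pattern indices and a 4-bit scale in the top nibble.
+ // Each block is weighted by (2*scale + 1) with a global 1/4 factor.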
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq3_xxs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[2];
+
+ ggml_int8x16x4_t q3s;
+ ggml_int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
+ const int8_t * restrict q8 = y[i].qs;
+ float sumf1 = 0, sumf2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+ memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
+ const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
+ const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
+ const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
+ const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
+ q3 += 16;
+ q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127))));
+ q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+ q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127))));
+ q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+ q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
+ q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+ q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+ q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+ sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
+ sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+ }
+ sumf += d*(sumf1 + sumf2);
+ }
+ *s = 0.5f * sumf;
+
+#elif defined(__AVX2__)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[2];
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ memcpy(aux32, gas, 8); gas += 8;
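+ // Look up the 8-byte sign patterns for the 4 x 7-bit indices packed in each aux32 word and
+ // apply them to q8 with _mm256_sign_epi8 (the patterns are +-1 bytes).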
+ const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
+ const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+ const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[0] >> 28;
+ const uint16_t ls2 = aux32[1] >> 28;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[2];
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
+ const int8_t * restrict q8 = y[i].qs;
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+ q3 += 8;
+ const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+ q3 += 8;
+ memcpy(aux32, gas, 8); gas += 8;
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+ const uint16_t ls1 = aux32[0] >> 28;
+ const uint16_t ls2 = aux32[1] >> 28;
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+ }
+
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+ }
+
+ *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ const vector int v0 = vec_splats((int32_t)0);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4);
+ const int8_t * restrict q8 = y[i].qs;
+
+#pragma GCC unroll 1
+ for (int j = 0; j < QK_K/32; j += 2) {
+ __builtin_prefetch(q3, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
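+ // Each q3 byte selects a 4-byte entry of iq3xxs_grid; the keven_signs entries are +-1 byte
+ // patterns that are multiplied in directly.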
+ vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+ vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+ vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+ vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+ q3 += 16;
+
+ vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])};
+ vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])};
+ vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])};
+ vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])};
+
+ vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0);
+ vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1);
+ vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2);
+ vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+ const uint16_t ls0 = (uint16_t)(signs[0] >> 28);
+ const uint16_t ls1 = (uint16_t)(signs[1] >> 28);
+ signs += 2;
+
+ vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+ vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = 0.25f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[2];
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ memcpy(aux32, gas, 8); gas += 8;
+
+ const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
+ const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+ const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[0] >> 28;
+ const uint16_t ls2 = aux32[1] >> 28;
+
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
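+ // Reference path: per 32-quant block read one packed word (4 x 7-bit sign indices plus a
+ // 4-bit scale), expand each grid byte to 4 quants and flip signs bit by bit.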
+ uint32_t aux32;
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+ memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+ const uint32_t ls = 2*(aux32 >> 28) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 4; ++j) {
+ sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+ sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ q3 += 8;
+ bsum += sumi * ls;
+ }
+ sumf += d * bsum;
+ }
+ *s = 0.25f * sumf;
+#endif
+}
+
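+ // Dot product of an IQ3_S row with a Q8_K row. Grid indices are 8 bits from qs plus a 9th bit
+ // from qh; sign bits live in x[i].signs, and scales are nibbles (one per 32-quant block, two
+ // per scale byte) with weight (2*scale + 1) and no extra global factor.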
+void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq3_s * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+ typedef union {
+ uint16x8_t vec_index;
+ uint16_t index[8];
+ } vec_index_t;
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+ static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+
+ const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+ const uint8x16_t mask2 = vld1q_u8(k_mask2);
+
+ const int16x8_t hshift = vld1q_s16(k_shift);
+ const uint16x8_t m256 = vdupq_n_u16(256);
+ const uint8x16_t m1 = vdupq_n_u8(1);
+
+ uint8x16x2_t vs;
+ ggml_int8x16x4_t q3s;
+ ggml_int8x16x4_t q8b;
+ vec_index_t idx;
+
+ uint32_t scales32[2];
+ const uint8_t * scales8 = (const uint8_t *)scales32;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+
+ memcpy(scales32, x[i].scales, 4);
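+ // Turn the 8 nibble scales into bytes holding 2*s+1: scales8[0..3] end up with the low
+ // nibbles (even blocks), scales8[4..7] with the high nibbles (odd blocks).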
+ scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
+ scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
+
+ int sumi1 = 0, sumi2 = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
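+ // Build 16 grid indices: the low 8 bits come from qs, the 9th bit is the matching bit of qh
+ // shifted into position by hshift and masked with 256.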
+ const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
+ idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
+ const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+ iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+ const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+ iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+ idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
+ const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+ iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+ const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+ iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+ vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+ q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
+ q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
+
+ vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+ vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+ vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+ vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+ vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+ signs += 4;
+
+ q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
+ q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
+
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+
+ sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
+ sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
+ }
+ sumf += d*(sumi1 + sumi2);
+ }
+ *s = sumf;
+
+#elif defined(__AVX2__)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+ const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+ const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+ const __m256i idx_mask = _mm256_set1_epi32(256);
+
+ typedef union {
+ __m256i vec[2];
+ uint32_t index[16];
+ } index_t;
+
+ index_t idx;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
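+ // Build the 16 grid indices: low 8 bits from qs, 9th bit taken from the matching bit of qh
+ // via a per-lane variable shift (idx_shift) and a mask of 256.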
+ const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
+ idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
+ idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
+ idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
+ idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
+ idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
+ idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
+
+ // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
+ //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+ //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+ const __m256i q2_1 = _mm256_set_epi32(
+ iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+ iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+ );
+ const __m256i q2_2 = _mm256_set_epi32(
+ iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+ iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+ );
+
+ __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+ const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+ const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+ aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+ const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+ const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
+ const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+ sumi1 = _mm256_add_epi32(sumi1, p1);
+ sumi2 = _mm256_add_epi32(sumi2, p2);
+ }
+
+ accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+ const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
+ const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
+ const __m128i idx_mask = _mm_set1_epi32(256);
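+ // This path has no per-lane variable shift, so the qh bit is moved into bit 8 by multiplying
+ // with a power of two (256 down to 2) before masking with 256.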
+
+ typedef union {
+ __m128i vec[4];
+ uint32_t index[16];
+ } index_t;
+
+ index_t idx;
+
+ __m256 accumf = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
+ const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
+ const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
+ idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
+ idx.vec[1] = idx.vec[0];
+ idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
+ idx.vec[3] = idx.vec[2];
+
+ idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
+ idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
+ idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
+ idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
+
+ idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
+ idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
+ idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
+ idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
+
+ const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
+ const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
+ const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
+ const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
+
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
+ __m128i aux128_1 = aux128_0;
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+ aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
+ aux128_1 = aux128_0;
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+ signs += 4;
+
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+ }
+
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+ }
+
+ *s = hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+ const vector int v0 = vec_splats((int32_t)0);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+ const vector unsigned char mask1 = vec_xl(16, k_mask1);
+ const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ const uint8_t * restrict q3 = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].signs);
+ const uint8_t * restrict sc = x[i].scales;
+ const int8_t * restrict q8 = y[i].qs;
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ for (int j = 0; j < QK_K/32; j += 2) {
+ __builtin_prefetch(q3, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
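+ // Gather 16 grid entries (4 bytes each): index = one q3 byte plus one qh bit shifted into bit 8.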
+ vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)],
+ iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]};
+ vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)],
+ iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]};
+ vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)],
+ iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]};
+ vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)],
+ iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]};
+ q3 += 16;
+ qh += 2;
+
+ vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+ vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+ signs += 4;
+
+ vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+ vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+ vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0);
+ vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1);
+
+ vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+ vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+ vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+ vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+ vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0);
+ vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1);
+ vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2);
+ vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+ const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+ const uint16_t ls1 = (uint16_t)(sc[0] >> 4);
+ sc ++;
+
+ vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+ vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+ const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+
+ __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
+ const __m256i idx_mask = __lasx_xvreplgr2vr_w(256);
+
+ typedef union {
+ __m256i vec[2];
+ uint32_t index[16];
+ } index_t;
+
+ index_t idx;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
+ idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
+ idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
+ idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
+ idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
+ idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
+ idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
+
+ // As in the AVX2 path above, the grid values are gathered with explicit per-lane sets; at least on x86 (Ryzen 7950X), _mm256_i32gather_epi32 measured slower than _mm256_set_epi32.
+ //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+ //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+ const __m256i q2_1 = lasx_set_w(
+ iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+ iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+ );
+ const __m256i q2_2 = lasx_set_w(
+ iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+ iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+ );
+
+ __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+ aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = hsum_float_8(accumf);
+
+#else
+
+ float sumf = 0.f;
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * restrict qs = x[i].qs;
+ const uint8_t * restrict qh = x[i].qh;
+ const uint8_t * restrict signs = x[i].signs;
+ const int8_t * restrict q8 = y[i].qs;
+ int32_t bsum = 0;
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
+ const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1;
+ int32_t sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
+ for (int j = 0; j < 4; ++j) {
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ qs += 8;
+ signs += 4;
+ bsum += sumi * ls1;
+ sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
+ for (int j = 0; j < 4; ++j) {
+ sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+ sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+ }
+ q8 += 8;
+ }
+ qs += 8;
+ signs += 4;
+ bsum += sumi * ls2;
+ }
+ sumf += d * bsum;
+ }
+ *s = sumf;
+#endif
+}
+
+
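+// Helpers: multiply the signed 8-bit lanes of x and y and horizontally add
+// adjacent pairs into signed 16-bit lanes. The maddubs-style instructions take
+// an unsigned first operand, so |x| is used and the sign of x is moved onto y.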
+#if defined(__AVX__)
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+ const __m128i ax = _mm_sign_epi8(x, x);
+ const __m128i sy = _mm_sign_epi8(y, x);
+ return _mm_maddubs_epi16(ax, sy);
+}
+#endif
+
+#if defined(__AVX2__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return _mm256_maddubs_epi16(ax, sy);
+}
+#elif defined(__loongarch_asx)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+ const __m256i ax = __lasx_xvsigncov_b(x, x);
+ const __m256i sy = __lasx_xvsigncov_b(x, y);
+ __m256i tmp1, tmp2, tmp3;
+ tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
+ tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
+ tmp3 = __lasx_xvadd_h(tmp1, tmp2);
+ return __lasx_xvsat_h(tmp3, 15);
+}
+#endif
+
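+// Dot product of an IQ1_S row (x) with a Q8_K row (y). Per 32-value sub-block,
+// four 8-bit indices in qs are each combined with 3 bits from qh to select an
+// 8-value entry of iq1s_grid; bits 12..14 of qh hold the block scale (used as
+// 2*s + 1) and bit 15 the sign of the IQ1S_DELTA shift, which is applied
+// through y's precomputed block sums (bsums).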
+void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq1_s * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+
+ ggml_int8x16x4_t q1b;
+ ggml_int8x16x4_t q8b;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ int sumi1 = 0, sumi2 = 0, sumi3 = 0;
+
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+ q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
+ q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
+ q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
+ q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
+ qs += 8;
+
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
+ const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
+
+ const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+ const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+ sumi1 += vaddvq_s32(p1) * ls1;
+ sumi2 += vaddvq_s32(p2) * ls2;
+ sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
+
+ }
+
+ sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
+
+ __m256 accum = _mm256_setzero_ps();
+ float accum1 = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ __m256i sumi = _mm256_setzero_si256();
+ int sumi1 = 0;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+ iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+ const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+ iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+ qs += 8;
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+ const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
+ const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
+
+ sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+ }
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
+ accum1 += d * sumi1;
+
+ }
+
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined __AVX__
+ __m256 accum = _mm256_setzero_ps();
+ float accum1 = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ int sumi1 = 0;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+ const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
+ const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+ const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
+ qs += 8;
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
+
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+ }
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
+ accum1 += d * sumi1;
+
+ }
+
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined(__POWER9_VECTOR__)
+ const vector unsigned char v0 = vec_splats((unsigned char)0x0);
+ const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ for (int i = 0; i < nb; ++i) {
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+ vector float vyd = vec_splats(y[i].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = vec_splats((int32_t)0);
+ vector signed int vsumi1 = vec_splats((int32_t)0);
+ vector signed int vsumi2 = vec_splats((int32_t)0);
+ vector signed int vsumi3 = vec_splats((int32_t)0);
+ vector signed int vsumi8 = vec_splats((int32_t)0);
+
+ const uint8_t * restrict q1 = x[i].qs;
+ const uint16_t * restrict qh = x[i].qh;
+ const int8_t * restrict q8 = y[i].qs;
+ const int16_t * restrict qs = y[i].bsums;
+
+ for (int j = 0; j < QK_K/32; j += 2) {
+ __builtin_prefetch(q1, 0, 1);
+ __builtin_prefetch(qh, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))};
+ vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))};
+ vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))};
+ vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))};
+ q1 += 8;
+
+ vector signed char q1x0 = (vector signed char)aux64x2_0;
+ vector signed char q1x1 = (vector signed char)aux64x2_1;
+ vector signed char q1x2 = (vector signed char)aux64x2_2;
+ vector signed char q1x3 = (vector signed char)aux64x2_3;
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3));
+
+ const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7);
+ const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7);
+
+ vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+ vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+ vector signed short vscales = vec_sld(vscales23, vscales01, 8);
+
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+
+ vector signed short q8ysums = vec_xl_len(qs, 8);
+ qs += 4;
+ q8ysums = vec_mergeh(q8ysums, (vector signed short)v0);
+
+ vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);
+ qh += 2;
+ vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);
+
+ vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);
+
+ vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ __m256 accum = (__m256)__lasx_xvldi(0);
+ float accum1 = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ __m256i sumi = __lasx_xvldi(0);
+ int sumi1 = 0;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+            // zero-initialize before inserting lanes to avoid reading an
+            // indeterminate value in the first __lasx_xvinsgr2vr_d call
+            __m256i q1b_1 = __lasx_xvldi(0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
+
+            __m256i q1b_2 = __lasx_xvldi(0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
+
+ qs += 8;
+ const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+
+ __m256i tmp1, tmp5, tmp6;
+ tmp1 = __lasx_xvreplgr2vr_h(ls1);
+ tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
+ const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
+
+ tmp1 = __lasx_xvreplgr2vr_h(ls2);
+ tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
+ const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
+
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+ }
+
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
+ accum1 += d * sumi1;
+ }
+
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#else
+
+ float sumf = 0;
+ for (int i = 0; i < nb; i++) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ int sumi = 0, sumi1 = 0;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const int ls = 2*((qh[ib] >> 12) & 7) + 1;
+ const int delta = qh[ib] & 0x8000 ? -1 : 1;
+ int lsum = 0;
+ for (int l = 0; l < 4; ++l) {
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+ for (int j = 0; j < 8; ++j) {
+ lsum += q8[j] * grid[j];
+ }
+ q8 += 8;
+ }
+ sumi += ls * lsum;
+ sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
+ qs += 4;
+ }
+
+ sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
+ }
+
+ *s = sumf;
+
+#endif
+}
+
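+// Dot product of an IQ1_M row (x) with a Q8_K row (y). The per-block fp16
+// super-scale is reassembled from the top nibble of each of the four 16-bit
+// words in x[i].scales. Grid indices combine qs with 3 bits from each nibble
+// of qh, qh bits 0x08/0x80 give the sign of the IQ1M_DELTA shift for each
+// group of 8 values, and 3-bit sub-scales (used as 2*s + 1) come from sc.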
+void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq1_m * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+ iq1m_scale_t scale;
+
+#if defined __ARM_NEON
+ const int32x4_t mask = vdupq_n_s32(0x7);
+ const int32x4_t mone = vdupq_n_s32(1);
+ const int32x4_t mzero = vdupq_n_s32(0);
+
+ ggml_int8x16x4_t deltas;
+ deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
+ deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
+ deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
+ deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
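+    // deltas.val[k] encodes the IQ1M delta signs for one group of 16 q8 values:
+    // bit 0 of k (from qh bit 0x08) selects -1 for the low 8 bytes, bit 1
+    // (from qh bit 0x80) selects -1 for the high 8 bytes.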
+
+ ggml_int8x16x4_t q1b;
+ ggml_int8x16x4_t q8b;
+
+ uint32_t aux32;
+ const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+ float sumf = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ int32x4_t sumi1 = mzero;
+ int32x4_t sumi2 = mzero;
+
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+ q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
+ q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
+ q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
+ q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
+ vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
+
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
+ const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
+ const int32x4_t p12 = vpaddq_s32(p1, p2);
+
+            const uint32_t * qh32 = (const uint32_t *)qh; // qh is 4-byte aligned, so this cast is safe
+ aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
+
+ const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
+ const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
+ const int32x4_t p34 = vpaddq_s32(p3, p4);
+
+ int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+
+ scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
+
+ sumi1 = vmlaq_s32(sumi1, scales_4, p12);
+ sumi2 = vmlaq_s32(sumi2, scales_4, p34);
+
+ qs += 8; qh += 4;
+
+ }
+
+ sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
+
+ const __m256i mask = _mm256_set1_epi16(0x7);
+ const __m256i mone = _mm256_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m256i q1b_1 = _mm256_set_epi64x(
+ iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+ iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
+ );
+ const __m256i q1b_2 = _mm256_set_epi64x(
+ iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+ iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
+ );
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+
+ const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+ const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+ const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
+ const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+
+ __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
+ __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+
+ scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
+ scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
+ const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
+ const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
+ const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
+ const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
+
+ sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
+ sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
+
+ qs += 8; qh += 4;
+ }
+
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+ accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
+ accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
+ }
+
+ *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#elif defined __AVX__
+ const __m128i mask = _mm_set1_epi16(0x7);
+ const __m128i mone = _mm_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q1b_1_0 = _mm_set_epi64x(
+ iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+ const __m128i q1b_1_1 = _mm_set_epi64x(
+ iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
+ const __m128i q1b_2_0 = _mm_set_epi64x(
+ iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+ const __m128i q1b_2_1 = _mm_set_epi64x(
+ iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+
+ const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+ const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+ const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+ const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+ qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+ const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
+ const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
+ const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
+ const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
+
+ __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
+ __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
+ __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
+ __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
+
+ scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
+ scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
+ scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
+ scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
+ const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
+ const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
+ const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
+ const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
+
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
+
+ qs += 8; qh += 4;
+ }
+
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+ accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
+ accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
+ }
+
+ *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#else
+
+ int sum1[2], sum2[2], delta[4];
+
+ float sumf = 0;
+ for (int i = 0; i < nb; i++) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint8_t * qh = x[i].qh;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+ int sumi1 = 0, sumi2 = 0;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ delta[0] = qh[0] & 0x08 ? -1 : 1;
+ delta[1] = qh[0] & 0x80 ? -1 : 1;
+ delta[2] = qh[1] & 0x08 ? -1 : 1;
+ delta[3] = qh[1] & 0x80 ? -1 : 1;
+ sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
+ for (int l = 0; l < 4; ++l) {
+ const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
+ int lsum1 = 0, lsum2 = 0;
+ for (int j = 0; j < 8; ++j) {
+ lsum1 += q8[j] * grid[j];
+ lsum2 += q8[j];
+ }
+ q8 += 8;
+ sum1[l/2] += lsum1;
+ sum2[l/2] += lsum2*delta[l];
+ }
+
+ const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
+ const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+
+ sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
+ sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
+ qs += 4;
+ qh += 2;
+ }
+
+ sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+ }
+
+ *s = sumf;
+
+#endif
+}
+
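+// Dot product of an IQ4_NL row (x) with a Q8_0 row (y). Each 32-value block
+// stores 16 bytes of 4-bit codes that are mapped through the non-linear
+// kvalues_iq4nl table (low nibbles -> first 16 values, high nibbles -> last
+// 16); the result is scaled by d(x)*d(y).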
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+ return;
+ }
+#endif
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK4_NL == 0);
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+ const block_iq4_nl * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+ const int nb = n / QK4_NL;
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined __ARM_NEON
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ uint8x16x2_t q4bits;
+ int8x16x4_t q4b;
+ int8x16x4_t q8b;
+ int32x4_t prod_1, prod_2;
+
+ for (; ib + 1 < nb; ib += 2) {
+
+ q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+ q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+ q8b.val[0] = vld1q_s8(y[ib + 0].qs);
+ q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
+ q8b.val[2] = vld1q_s8(y[ib + 1].qs);
+ q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
+
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+ q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+ q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+ prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+ prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+ sumf +=
+ GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+ GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+ }
+
+#elif defined __AVX2__
+
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+ const __m256i mone = _mm256_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+ _mm256_cvtepi32_ps(p_1), accum1);
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+ _mm256_cvtepi32_ps(p_2), accum2);
+ }
+
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+ const __m128i mone = _mm_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+ accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
+ accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
+ }
+
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+
+ const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+#pragma GCC unroll 4
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+ vector signed char q4x0 = vec_and(qxs, lowMask);
+ vector signed char q4x1 = vec_sr(qxs, v4);
+
+ q4x0 = vec_perm(values, values, (vector unsigned char)q4x0);
+ q4x1 = vec_perm(values, values, (vector unsigned char)q4x1);
+
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+ vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+
+ vsumi0 = vec_sum4s(qv0, vsumi0);
+ vsumi1 = vec_sum4s(qv1, vsumi1);
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ sumf = vec_extract(vsumf0, 0);
+
+#elif defined (__loongarch_asx)
+
+ const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+ const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
+ const __m256i mone = __lasx_xvreplgr2vr_h(1);
+
+ __m256 accum1 = (__m256)__lasx_xvldi(0);
+ __m256 accum2 = (__m256)__lasx_xvldi(0);
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+ const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+ const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
+ const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
+ lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
+ const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
+ lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const __m256i p_1 = lasx_madd_h(p16_1, mone);
+ const __m256i p_2 = lasx_madd_h(p16_2, mone);
+ accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+ __lasx_xvffint_s_w(p_1), accum1);
+ accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+ __lasx_xvffint_s_w(p_2), accum2);
+ }
+
+ sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+
+#endif
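+    // Scalar tail: handles any blocks not consumed by the SIMD loops above
+    // (all of them when no SIMD path is compiled in).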
+ for (; ib < nb; ++ib) {
+ const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
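+// Dot product of an IQ4_XS row (x) with a Q8_K row (y). Values use the same
+// kvalues_iq4nl mapping as IQ4_NL, but each 32-value sub-block has a 6-bit
+// scale: the low 4 bits come from a nibble of scales_l, the top 2 bits from
+// scales_h, offset by -32.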
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_K == 0);
+
+ const block_iq4_xs * restrict x = vx;
+ const block_q8_K * restrict y = vy;
+
+ const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+ const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ ggml_uint8x16x2_t q4bits;
+ ggml_int8x16x4_t q4b;
+ ggml_int8x16x4_t q8b;
+ int32x4_t prod_1, prod_2;
+
+ float sumf = 0;
+
+ for (int ibl = 0; ibl < nb; ++ibl) {
+
+ const int8_t * q8 = y[ibl].qs;
+ const uint8_t * q4 = x[ibl].qs;
+ uint16_t h = x[ibl].scales_h;
+
+ int sumi1 = 0, sumi2 = 0;
+ for (int ib = 0; ib < QK_K/64; ++ib) {
+
+ q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+ q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+ q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+ q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+ prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+ prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+ int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+ int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
+ h >>= 4;
+ sumi1 += vaddvq_s32(prod_1) * ls1;
+ sumi2 += vaddvq_s32(prod_2) * ls2;
+
+ }
+
+ sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+ }
+
+ *s = sumf;
+
+#elif defined __AVX2__
+
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+
+ __m256 accum = _mm256_setzero_ps();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ uint16_t sh = x[ibl].scales_h;
+ __m256i sumi1 = _mm256_setzero_si256();
+ __m256i sumi2 = _mm256_setzero_si256();
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+ sh >>= 4;
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+ sumi1 = _mm256_add_epi32(p_1, sumi1);
+ sumi2 = _mm256_add_epi32(p_2, sumi2);
+ }
+ accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+ _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
+ }
+
+ *s = hsum_float_8(accum);
+
+#elif defined __AVX__
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+
+ __m256 accum = _mm256_setzero_ps();
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ uint16_t sh = x[ibl].scales_h;
+ __m128i sumi1_0 = _mm_setzero_si128();
+ __m128i sumi1_1 = _mm_setzero_si128();
+ __m128i sumi2_0 = _mm_setzero_si128();
+ __m128i sumi2_1 = _mm_setzero_si128();
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+ sh >>= 4;
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
+ sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
+ sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
+ sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
+ sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
+ }
+ __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
+ __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+ _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
+ }
+
+ *s = hsum_float_8(accum);
+
+#elif defined(__POWER9_VECTOR__)
+ const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+ vector float vsumf0 = vec_splats(0.0f);
+ vector float vsumf1 = vec_splats(0.0f);
+ vector float vsumf2 = vec_splats(0.0f);
+ vector float vsumf3 = vec_splats(0.0f);
+
+ const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+ for (int ibl = 0; ibl < nb; ++ibl) {
+
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d));
+ vector float vyd = vec_splats(y[ibl].d);
+ vector float vd = vec_mul(vxd, vyd);
+
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+
+ uint16_t h = x[ibl].scales_h;
+
+ const uint8_t * restrict q4 = x[ibl].qs;
+ const uint8_t * restrict sc = x[ibl].scales_l;
+ const int8_t * restrict q8 = y[ibl].qs;
+
+ for (int ib = 0; ib < QK_K/64; ib ++ ) {
+ __builtin_prefetch(q4, 0, 1);
+ __builtin_prefetch(q8, 0, 1);
+
+ vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+ vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+ q4 += 32;
+
+ vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask);
+ vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4);
+ vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask);
+ vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4);
+
+ q4x00 = vec_perm(values, values, (vector unsigned char)q4x00);
+ q4x01 = vec_perm(values, values, (vector unsigned char)q4x01);
+ q4x10 = vec_perm(values, values, (vector unsigned char)q4x10);
+ q4x11 = vec_perm(values, values, (vector unsigned char)q4x11);
+
+ vector signed char q8y0 = vec_xl( 0, q8);
+ vector signed char q8y1 = vec_xl(16, q8);
+ vector signed char q8y2 = vec_xl(32, q8);
+ vector signed char q8y3 = vec_xl(48, q8);
+ q8 += 64;
+
+ vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0));
+ vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1));
+ vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2));
+ vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3));
+
+ const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32);
+ const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32);
+ h >>= 4;
+ sc ++;
+
+ vector signed short vscales01 = vec_splats((int16_t)ls0);
+ vector signed short vscales23 = vec_splats((int16_t)ls1);
+
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+ }
+
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+ vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+ vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+ vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+ }
+
+ vsumf0 = vec_add(vsumf0, vsumf2);
+ vsumf1 = vec_add(vsumf1, vsumf3);
+
+ vsumf0 = vec_add(vsumf0, vsumf1);
+
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+ *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+ const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+ const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
+
+ __m256 accum = (__m256)__lasx_xvldi(0);
+ __m256i tmp1;
+ __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
+
+ mask_8f = __lsx_vreplgr2vr_b(0x8f);
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ uint16_t sh = x[ibl].scales_h;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ __m128i zero = __lsx_vldi(0);
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+ const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
+ tmp0 = __lsx_vori_b(tmp2, 0x10);
+ mask = __lsx_vsle_b(zero, tmp2);
+ tmp3 = __lsx_vand_v(tmp0, mask);
+ tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+ tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
+ tmp0 = __lsx_vori_b(tmp2, 0x10);
+ mask = __lsx_vsle_b(zero, tmp2);
+ tmp4 = __lsx_vand_v(tmp0, mask);
+ tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+ const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
+
+ tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
+ tmp0 = __lsx_vori_b(tmp2, 0x10);
+ mask = __lsx_vsle_b(zero, tmp2);
+ tmp3 = __lsx_vand_v(tmp0, mask);
+ tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+ tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
+ tmp0 = __lsx_vori_b(tmp2, 0x10);
+ mask = __lsx_vsle_b(zero, tmp2);
+ tmp4 = __lsx_vand_v(tmp0, mask);
+ tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+ const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
+
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+ sh >>= 4;
+ __m256i tmp5, tmp6;
+ tmp1 = __lasx_xvreplgr2vr_h(ls1);
+ tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
+ const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
+ tmp1 = __lasx_xvreplgr2vr_h(ls2);
+ tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
+ const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+ sumi1 = __lasx_xvadd_w(p_1, sumi1);
+ sumi2 = __lasx_xvadd_w(p_2, sumi2);
+ }
+ accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+ __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
+ }
+
+ *s = hsum_float_8(accum);
+
+#else
+ float sumf = 0;
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+ uint16_t h = x[ibl].scales_h;
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+ const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
+ h >>= 4;
+ const float d1 = d4d8*(ls1 - 32);
+ const float d2 = d4d8*(ls2 - 32);
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 16; ++j) {
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
+ }
+ sumf += d1 * (sumi1 + sumi2);
+ qs += 16;
+ q8 += 32;
+ sumi1 = sumi2 = 0;
+ for (int j = 0; j < 16; ++j) {
+ sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+ sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
+ }
+ sumf += d2 * (sumi1 + sumi2);
+ qs += 16;
+ q8 += 32;
+ }
+ }
+ *s = sumf;
+#endif
+}
+
+// ================================ IQ2 quantization =============================================
+
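+// Per-type lookup tables for the grid-based IQ2/IQ1 quants: the codebook grid
+// itself, a map from packed lattice point to grid index, and a table of
+// nearest-neighbour candidates consulted when a point is not on the grid.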
+typedef struct {
+ uint64_t * grid;
+ int * map;
+ uint16_t * neighbours;
+} iq2_entry_t;
+
+static iq2_entry_t iq2_data[4] = {
+ {NULL, NULL, NULL},
+ {NULL, NULL, NULL},
+ {NULL, NULL, NULL},
+ {NULL, NULL, NULL},
+};
+
+static inline int iq2_data_index(enum ggml_type type) {
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+ return type == GGML_TYPE_IQ2_XXS ? 0 :
+ type == GGML_TYPE_IQ2_XS ? 1 :
+ type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;
+}
+
+static inline int iq2_grid_size(enum ggml_type type) {
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+ return type == GGML_TYPE_IQ2_XXS ? 256 :
+ type == GGML_TYPE_IQ2_XS ? 512 :
+ type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;
+}
+
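+// Lexicographic comparison of pairs of ints, suitable for qsort.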
+static int iq2_compare_func(const void * left, const void * right) {
+ const int * l = (const int *)left;
+ const int * r = (const int *)right;
+ return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
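+// Builds the grid/map/neighbour tables for the given type on first use; later
+// calls return immediately once the grid pointer is set.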
+void iq2xs_init_impl(enum ggml_type type) {
+ const int gindex = iq2_data_index(type);
+ const int grid_size = iq2_grid_size(type);
+ if (iq2_data[gindex].grid) {
+ return;
+ }
+ static const uint16_t kgrid_2bit_256[256] = {
+ 0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
+ 100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
+ 1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
+ 1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113,
+ 2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240,
+ 4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400,
+ 5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260,
+ 8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872,
+ 10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+ 16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+ 17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+ 20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+ 22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+ 25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+ 33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+ 37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+ };
+ static const uint16_t kgrid_2bit_512[512] = {
+ 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
+ 73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
+ 260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
+ 352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597,
+ 640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096,
+ 1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348,
+ 1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065,
+ 2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441,
+ 2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160,
+ 4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372,
+ 4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125,
+ 5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652,
+ 5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197,
+ 8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549,
+ 8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894,
+ 10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+ 16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+ 16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+ 16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+ 17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+ 18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+ 20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+ 21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+ 22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+ 24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+ 32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+ 33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+ 33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+ 35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+ 37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+ 40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+ 42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+ };
+ static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {
+ 0, 2, 5, 8, 10, 17, 21, 32, 34, 40, 42, 69, 81, 84, 86, 101,
+ 128, 130, 136, 138, 149, 160, 162, 168, 170, 260, 261, 273, 276, 278, 281, 282,
+ 293, 321, 326, 329, 338, 341, 346, 353, 356, 358, 360, 389, 401, 404, 406, 421,
+ 512, 514, 520, 522, 533, 544, 546, 552, 554, 581, 593, 601, 612, 617, 640, 642,
+ 648, 650, 657, 661, 665, 672, 674, 680, 682, 1041, 1044, 1046, 1061, 1089, 1097, 1109,
+ 1114, 1124, 1125, 1169, 1177, 1189, 1281, 1284, 1285, 1286, 1301, 1304, 1306, 1321, 1344, 1349,
+ 1354, 1360, 1361, 1364, 1365, 1366, 1369, 1376, 1378, 1381, 1384, 1386, 1409, 1425, 1429, 1432,
+ 1434, 1441, 1444, 1445, 1446, 1449, 1556, 1561, 1601, 1604, 1616, 1618, 1621, 1624, 1632, 1633,
+ 1638, 1641, 1669, 1681, 1684, 1689, 2048, 2050, 2056, 2058, 2069, 2080, 2082, 2088, 2090, 2117,
+ 2129, 2134, 2149, 2176, 2178, 2184, 2186, 2197, 2208, 2210, 2216, 2218, 2309, 2321, 2324, 2329,
+ 2340, 2341, 2369, 2384, 2385, 2389, 2401, 2404, 2409, 2449, 2452, 2454, 2457, 2469, 2560, 2562,
+ 2568, 2570, 2581, 2592, 2594, 2600, 2602, 2629, 2641, 2649, 2657, 2661, 2688, 2690, 2693, 2696,
+ 2698, 2709, 2720, 2722, 2728, 2730, 4112, 4113, 4116, 4121, 4132, 4133, 4161, 4164, 4176, 4181,
+ 4184, 4193, 4196, 4197, 4201, 4241, 4244, 4246, 4257, 4261, 4353, 4356, 4358, 4361, 4368, 4370,
+ 4373, 4376, 4385, 4388, 4393, 4421, 4426, 4432, 4433, 4434, 4436, 4437, 4438, 4441, 4448, 4453,
+ 4484, 4498, 4501, 4513, 4516, 4625, 4628, 4630, 4645, 4672, 4678, 4681, 4690, 4693, 4696, 4698,
+ 4708, 4710, 4741, 4753, 4756, 4758, 4773, 5121, 5126, 5129, 5140, 5141, 5144, 5145, 5153, 5158,
+ 5185, 5189, 5190, 5192, 5194, 5201, 5204, 5205, 5206, 5209, 5218, 5221, 5224, 5252, 5257, 5264,
+ 5268, 5269, 5272, 5273, 5274, 5281, 5284, 5285, 5289, 5378, 5381, 5386, 5393, 5396, 5397, 5398,
+ 5401, 5408, 5410, 5413, 5416, 5418, 5441, 5444, 5445, 5446, 5457, 5458, 5460, 5461, 5462, 5465,
+ 5466, 5473, 5476, 5477, 5478, 5481, 5504, 5506, 5508, 5509, 5512, 5514, 5520, 5521, 5524, 5525,
+ 5526, 5529, 5530, 5536, 5538, 5541, 5633, 5636, 5637, 5638, 5653, 5654, 5656, 5658, 5665, 5670,
+ 5696, 5698, 5700, 5701, 5704, 5706, 5713, 5717, 5718, 5720, 5721, 5729, 5732, 5733, 5736, 5737,
+ 5738, 5766, 5770, 5778, 5781, 5796, 5801, 6161, 6166, 6181, 6209, 6212, 6214, 6217, 6224, 6229,
+ 6232, 6234, 6240, 6241, 6244, 6246, 6249, 6277, 6289, 6292, 6309, 6416, 6418, 6421, 6426, 6433,
+ 6437, 6466, 6468, 6469, 6472, 6481, 6484, 6485, 6486, 6489, 6490, 6496, 6501, 6506, 6537, 6545,
+ 6546, 6549, 6552, 6561, 6566, 6569, 6665, 6678, 6692, 6694, 6724, 6726, 6729, 6736, 6738, 6741,
+ 6744, 6753, 6758, 6761, 6789, 6801, 6806, 6810, 8192, 8194, 8200, 8202, 8213, 8224, 8226, 8229,
+ 8232, 8234, 8261, 8273, 8281, 8289, 8293, 8320, 8322, 8328, 8330, 8341, 8352, 8354, 8357, 8360,
+ 8362, 8453, 8465, 8468, 8473, 8485, 8514, 8516, 8521, 8533, 8536, 8538, 8545, 8548, 8549, 8550,
+ 8581, 8592, 8598, 8601, 8613, 8705, 8712, 8714, 8721, 8725, 8736, 8738, 8744, 8746, 8773, 8785,
+ 8790, 8793, 8805, 8833, 8840, 8842, 8849, 8853, 8864, 8866, 8872, 8874, 9221, 9236, 9238, 9241,
+ 9253, 9284, 9285, 9286, 9289, 9298, 9301, 9304, 9306, 9318, 9349, 9361, 9364, 9369, 9377, 9381,
+ 9481, 9493, 9505, 9513, 9536, 9541, 9544, 9553, 9556, 9557, 9561, 9570, 9573, 9576, 9609, 9616,
+ 9620, 9621, 9624, 9626, 9633, 9636, 9638, 9641, 9733, 9744, 9746, 9753, 9765, 9793, 9801, 9813,
+ 9824, 9825, 9833, 9860, 9862, 9872, 9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
+ 10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
+ 10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
+ 10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
+ 10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
+ 16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
+ 16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
+ 16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
+ 16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
+ 17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
+ 17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
+ 17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
+ 17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
+ 17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
+ 18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
+ 18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
+ 18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
+ 18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
+ 19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
+ 20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
+ 20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
+ 20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
+ 20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
+ 20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
+ 21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
+ 21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
+ 21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
+ 21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
+ 21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
+ 21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
+ 21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
+ 21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
+ 22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
+ 22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
+ 22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
+ 22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
+ 22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
+ 22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
+ 22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
+ 23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
+ 23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
+ 24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
+ 24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
+ 24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
+ 25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
+ 25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
+ 25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
+ 25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
+ 26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
+ 26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
+ 26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
+ 26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
+ 26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
+ 27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
+ 27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
+ 32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
+ 33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
+ 33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
+ 33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
+ 33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
+ 34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
+ 34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
+ 34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
+ 34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
+ 35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
+ 35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
+ 35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
+ 36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
+ 37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
+ 37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
+ 37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
+ 37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
+ 37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
+ 38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
+ 38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
+ 38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
+ 38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
+ 38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
+ 39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
+ 39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
+ 39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
+ 39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
+ 41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
+ 41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
+ 41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
+ 41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
+ 42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
+ 42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
+ 42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
+ 42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
+ 43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
+ 43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
+ 43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
+ };
+ static const uint16_t kgrid_2bit_1024[1024] = {
+ 0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
+ 73, 80, 82, 85, 88, 97, 100, 102, 105, 128, 130, 133, 136, 145, 148, 160,
+ 165, 170, 257, 260, 262, 265, 272, 274, 277, 280, 289, 292, 320, 322, 325, 328,
+ 337, 340, 342, 345, 352, 357, 360, 385, 388, 400, 402, 405, 417, 420, 512, 514,
+ 517, 520, 529, 532, 544, 554, 577, 580, 582, 585, 592, 597, 640, 645, 650, 660,
+ 674, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1062, 1065, 1088, 1090, 1093,
+ 1096, 1098, 1105, 1108, 1110, 1113, 1120, 1122, 1125, 1153, 1156, 1158, 1161, 1168, 1173, 1176,
+ 1185, 1188, 1280, 1282, 1285, 1288, 1290, 1297, 1300, 1302, 1305, 1312, 1317, 1320, 1345, 1348,
+ 1350, 1353, 1360, 1362, 1365, 1368, 1377, 1380, 1408, 1410, 1413, 1416, 1425, 1428, 1440, 1537,
+ 1540, 1542, 1545, 1552, 1557, 1600, 1605, 1608, 1617, 1620, 1632, 1665, 1668, 1680, 2048, 2050,
+ 2053, 2056, 2065, 2068, 2070, 2073, 2080, 2085, 2090, 2113, 2116, 2118, 2121, 2128, 2130, 2133,
+ 2136, 2145, 2148, 2176, 2181, 2196, 2218, 2305, 2308, 2320, 2322, 2325, 2328, 2337, 2368, 2373,
+ 2376, 2385, 2388, 2400, 2433, 2448, 2560, 2577, 2580, 2594, 2600, 2602, 2640, 2713, 4097, 4100,
+ 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4134, 4160, 4162, 4165, 4168, 4177, 4180, 4182,
+ 4185, 4192, 4194, 4197, 4200, 4225, 4228, 4230, 4240, 4245, 4248, 4257, 4260, 4352, 4354, 4357,
+ 4360, 4362, 4369, 4372, 4374, 4377, 4384, 4386, 4389, 4392, 4417, 4420, 4422, 4425, 4432, 4434,
+ 4437, 4440, 4449, 4452, 4480, 4482, 4485, 4488, 4497, 4500, 4609, 4612, 4617, 4624, 4629, 4641,
+ 4644, 4672, 4677, 4689, 4692, 4737, 4740, 4752, 5120, 5122, 5125, 5128, 5137, 5140, 5142, 5145,
+ 5152, 5157, 5160, 5185, 5188, 5190, 5193, 5200, 5202, 5205, 5208, 5217, 5220, 5248, 5250, 5253,
+ 5256, 5265, 5268, 5280, 5377, 5380, 5382, 5385, 5392, 5394, 5397, 5400, 5409, 5412, 5440, 5442,
+ 5445, 5448, 5457, 5460, 5472, 5505, 5508, 5520, 5632, 5637, 5640, 5649, 5652, 5664, 5697, 5700,
+ 5712, 5760, 5802, 6145, 6148, 6150, 6153, 6160, 6165, 6168, 6177, 6208, 6210, 6213, 6216, 6225,
+ 6228, 6240, 6273, 6276, 6400, 6402, 6405, 6408, 6417, 6420, 6432, 6465, 6468, 6480, 6505, 6562,
+ 6660, 6672, 6720, 6742, 8192, 8194, 8197, 8200, 8209, 8212, 8214, 8217, 8224, 8229, 8234, 8257,
+ 8260, 8272, 8274, 8277, 8292, 8320, 8330, 8340, 8362, 8449, 8452, 8464, 8466, 8469, 8481, 8512,
+ 8514, 8517, 8529, 8532, 8544, 8577, 8580, 8592, 8704, 8714, 8738, 8744, 8746, 8772, 8784, 8840,
+ 8842, 8872, 9217, 9220, 9222, 9225, 9232, 9237, 9240, 9249, 9252, 9280, 9282, 9285, 9288, 9297,
+ 9300, 9312, 9345, 9348, 9360, 9472, 9477, 9480, 9489, 9492, 9504, 9537, 9540, 9552, 9574, 9600,
+ 9729, 9732, 9744, 9792, 9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500,
+ 10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410,
+ 16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513,
+ 16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674,
+ 16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785,
+ 16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025,
+ 17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476,
+ 17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665,
+ 17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760,
+ 17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085,
+ 18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528,
+ 18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948,
+ 18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548,
+ 20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740,
+ 20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865,
+ 20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510,
+ 21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636,
+ 21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054,
+ 22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800,
+ 22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645,
+ 24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912,
+ 24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680,
+ 25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880,
+ 26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850,
+ 32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060,
+ 33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345,
+ 33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873,
+ 33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176,
+ 34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076,
+ 35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928,
+ 36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200,
+ 37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968,
+ 38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976,
+ 39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130,
+ 41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121,
+ 42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690,
+ };
+
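+    // Build the lookup tables for this type: kgrid_q2xs holds the grid points (8 int8_t coordinates
+    // 2*l+1 packed into one uint64_t, with l taken 2 bits at a time from the kgrid_* entries above,
+    // e.g. entry 0x12 -> l = {2,0,1,0,...} -> coordinates {5,1,3,1,1,1,1,1}); kmap_q2xs maps a packed
+    // index back to its grid position, or to -(offset+1) into kneighbors_q2xs when the point is not
+    // on the grid; kneighbors_q2xs stores, for each off-grid point, a count followed by the indices
+    // of its nearest grid points (the nwant closest distance shells).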
+ const int kmap_size = 43692;
+ //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
+ const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
+ const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
+ type == GGML_TYPE_IQ2_XS ? kgrid_2bit_512 :
+ type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024;
+ uint64_t * kgrid_q2xs;
+ int * kmap_q2xs;
+ uint16_t * kneighbors_q2xs;
+
+ //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+ uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+ for (int k = 0; k < grid_size; ++k) {
+ int8_t * pos = (int8_t *)(the_grid + k);
+ for (int i = 0; i < 8; ++i) {
+ int l = (kgrid[k] >> 2*i) & 0x3;
+ pos[i] = 2*l + 1;
+ }
+ }
+ kgrid_q2xs = the_grid;
+ iq2_data[gindex].grid = the_grid;
+ kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+ iq2_data[gindex].map = kmap_q2xs;
+ for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+ uint64_t aux64;
+ uint8_t * aux8 = (uint8_t *)&aux64;
+ for (int i = 0; i < grid_size; ++i) {
+ aux64 = kgrid_q2xs[i];
+ uint16_t index = 0;
+ for (int k=0; k<8; ++k) {
+ uint16_t q = (aux8[k] - 1)/2;
+ index |= (q << 2*k);
+ }
+ kmap_q2xs[index] = i;
+ }
+ int8_t pos[8];
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+ int num_neighbors = 0, num_not_in_map = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q2xs[i] >= 0) continue;
+ ++num_not_in_map;
+ for (int k = 0; k < 8; ++k) {
+ int l = (i >> 2*k) & 0x3;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+ int n = 0; int d2 = dist2[0];
+ int nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ ++n;
+ }
+ num_neighbors += n;
+ }
+ //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+ kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+ iq2_data[gindex].neighbours = kneighbors_q2xs;
+ int counter = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q2xs[i] >= 0) continue;
+ for (int k = 0; k < 8; ++k) {
+ int l = (i >> 2*k) & 0x3;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+ kmap_q2xs[i] = -(counter + 1);
+ int d2 = dist2[0];
+ uint16_t * start = &kneighbors_q2xs[counter++];
+ int n = 0, nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ kneighbors_q2xs[counter++] = dist2[2*j+1];
+ ++n;
+ }
+ *start = n;
+ }
+ free(dist2);
+}
+
+void iq2xs_free_impl(enum ggml_type type) {
+ GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+ const int gindex = iq2_data_index(type);
+ if (iq2_data[gindex].grid) {
+ free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
+ free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
+ free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+ }
+}
+
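+// Given the precomputed neighbour list of an off-grid point (first element is the count), return the
+// neighbouring grid point that minimizes the weighted squared error sum_i weight[i]*(scale*g[i] - xval[i])^2,
+// and store its 2-bit codes (g[i]-1)/2 in L.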
+static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+ const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+ int num_neighbors = neighbours[0];
+ GGML_ASSERT(num_neighbors > 0);
+ float best_d2 = FLT_MAX;
+ int grid_index = -1;
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float d2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = pg[i];
+ float diff = scale*q - xval[i];
+ d2 += weight[i]*diff*diff;
+ }
+ if (d2 < best_d2) {
+ best_d2 = d2; grid_index = neighbours[j];
+ }
+ }
+ GGML_ASSERT(grid_index >= 0);
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
+ for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+ return grid_index;
+}
+
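+// Quantize one row to IQ2_XXS. Each super-block of QK_K values is split into groups of 32; within a group,
+// values are made non-negative (signs are stored separately, 7 bits per 8 values with the 8th sign implied
+// by even parity), a common scale is searched for, and every 8 values are mapped to the nearest point of the
+// 256-entry grid (falling back to the neighbour tables when the rounded point is not on the grid). The group
+// scales are then stored as 4 bits each relative to the fp16 super-block scale d.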
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 3;
+
+ const int64_t nbl = n/QK_K;
+
+ block_iq2_xxs * y = vy;
+
+ float scales[QK_K/32];
+ float weight[32];
+ float xval[32];
+ int8_t L[32];
+ int8_t Laux[32];
+ float waux[32];
+ uint8_t block_signs[4];
+ uint32_t q2[2*(QK_K/32)];
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+ memset(q2, 0, QK_K/4);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const float * xb = xbl + 32*ib;
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 4; ++k) {
+ int nflip = 0;
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+ }
+ }
+ if (nflip%2) {
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+ for (int i = 1; i < 8; ++i) {
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+ if (ax < min) {
+ min = ax; imin = i;
+ }
+ }
+ xval[8*k+imin] = -xval[8*k+imin];
+ s ^= (1 << imin);
+ }
+ block_signs[k] = s & 127;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+ if (max < GROUP_MAX_EPS) {
+ scales[ib] = 0;
+ memset(L, 0, 32);
+ continue;
+ }
+ float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
+ float eff_max = scale*kMaxQ;
+ float best = 0;
+ for (int is = -6; is <= 6; ++is) {
+ float id = (2*kMaxQ-1+is*0.1f)/eff_max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 4; ++k) {
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ memcpy(L, Laux, 32);
+ }
+ }
+ if (scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 4; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 2*i);
+ }
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+ }
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+ for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+                // This should never happen, but just in case, flip the scale so that it is positive (we use uints to encode the scale)
+                // and correspondingly flip the quant signs.
+ scale = -scale;
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+ }
+ for (int k = 0; k < 4; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+ printf("\n");
+ GGML_ASSERT(false);
+ }
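+                // pack: byte k of q2[2*ib+0] holds the 8-bit grid index of this group of 8 values,
+                // bits 7*k..7*k+6 of q2[2*ib+1] hold its 7 sign bits (bits 28..31 receive the 4-bit scale later)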
+ q2[2*ib+0] |= ((uint32_t) grid_index << 8*k);
+ q2[2*ib+1] |= (block_signs[k] << 7*k);
+ }
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ memset(y[ibl].qs, 0, QK_K/4);
+ continue;
+ }
+
+ float d = max_scale/31;
+ y[ibl].d = GGML_FP32_TO_FP16(d);
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ q2[2*ib+1] |= ((uint32_t)l << 28);
+ }
+ memcpy(y[ibl].qs, q2, QK_K/4);
+ }
+}
+
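+// Quantize one row to IQ2_XS. Same scheme as IQ2_XXS, but with groups of 16 values, a 512-entry grid,
+// and a different packing: each group of 8 values becomes one uint16_t holding the 9-bit grid index plus
+// 7 sign bits, while the 4-bit group scales go into the separate scales[] nibbles of the block.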
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 3;
+
+ const int64_t nbl = n/QK_K;
+
+ block_iq2_xs * y = vy;
+
+ float scales[QK_K/16];
+ float weight[16];
+ float xval[16];
+ int8_t L[16];
+ int8_t Laux[16];
+ float waux[16];
+ bool is_on_grid[2];
+ bool is_on_grid_aux[2];
+ uint8_t block_signs[2];
+ uint16_t q2[2*(QK_K/16)];
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+ memset(q2, 0, QK_K/4);
+ memset(y[ibl].scales, 0, QK_K/32);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ const float * xb = xbl + 16*ib;
+ const float * qw = quant_weights + QK_K*ibl + 16*ib;
+ for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 2; ++k) {
+ int nflip = 0;
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+ }
+ }
+ if (nflip%2) {
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+ for (int i = 1; i < 8; ++i) {
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+ if (ax < min) {
+ min = ax; imin = i;
+ }
+ }
+ xval[8*k+imin] = -xval[8*k+imin];
+ s ^= (1 << imin);
+ }
+ block_signs[k] = s & 127;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+ if (max < GROUP_MAX_EPS) {
+ scales[ib] = 0;
+ memset(L, 0, 16);
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ is_on_grid[0] = is_on_grid[1] = true;
+ for (int is = -9; is <= 9; ++is) {
+ float id = (2*kMaxQ-1+is*0.1f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 2; ++k) {
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+ for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 2; ++k) {
+ if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 2*i);
+ L[8*k + i] = l;
+ }
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+ scale = -scale;
+ for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+ }
+ for (int k = 0; k < 2; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+ printf("\n");
+ GGML_ASSERT(false);
+ }
+ q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+ }
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ memset(y[ibl].qs, 0, QK_K/4);
+ continue;
+ }
+
+ float d = max_scale/31;
+ y[ibl].d = GGML_FP32_TO_FP16(d);
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+ else y[ibl].scales[ib/2] |= (l << 4);
+ }
+ memcpy(y[ibl].qs, q2, QK_K/4);
+
+ }
+}
+
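+// Row-wise entry points: quantize nrow rows of n_per_row values each and return the total size in bytes.
+// The importance weights are required for these types (see the asserts in the *_impl functions above).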
+size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_xxs);
+ }
+ return nrow * nblock * sizeof(block_iq2_xxs);
+}
+
+size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_xs);
+ }
+ return nrow * nblock * sizeof(block_iq2_xs);
+}
+
+//
+// ============================================= 3-bit using D4 lattice
+//
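+// Unlike the IQ2 types above, the 3-bit types use groups of 4 values with 3 bits per coordinate
+// (odd values 1..15), so the index space is 8^4 = 4096 and two grids are used: 256 points for
+// IQ3_XXS and 512 points for IQ3_S.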
+
+typedef struct {
+ uint32_t * grid;
+ int * map;
+ uint16_t * neighbours;
+} iq3_entry_t;
+
+static iq3_entry_t iq3_data[2] = {
+ {NULL, NULL, NULL},
+ {NULL, NULL, NULL},
+};
+
+static inline int iq3_data_index(int grid_size) {
+ (void)grid_size;
+ GGML_ASSERT(grid_size == 256 || grid_size == 512);
+ return grid_size == 256 ? 0 : 1;
+}
+
+static int iq3_compare_func(const void * left, const void * right) {
+ const int * l = (const int *)left;
+ const int * r = (const int *)right;
+ return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+void iq3xs_init_impl(int grid_size) {
+ const int gindex = iq3_data_index(grid_size);
+ if (iq3_data[gindex].grid) {
+ return;
+ }
+ static const uint16_t kgrid_256[256] = {
+ 0, 2, 4, 9, 11, 15, 16, 18, 25, 34, 59, 61, 65, 67, 72, 74,
+ 81, 85, 88, 90, 97, 108, 120, 128, 130, 132, 137, 144, 146, 153, 155, 159,
+ 169, 175, 189, 193, 199, 200, 202, 213, 248, 267, 287, 292, 303, 315, 317, 321,
+ 327, 346, 362, 413, 436, 456, 460, 462, 483, 497, 513, 515, 520, 522, 529, 531,
+ 536, 538, 540, 551, 552, 576, 578, 585, 592, 594, 641, 643, 648, 650, 657, 664,
+ 698, 704, 706, 720, 729, 742, 758, 769, 773, 808, 848, 852, 870, 889, 901, 978,
+ 992, 1024, 1026, 1033, 1035, 1040, 1042, 1046, 1049, 1058, 1089, 1091, 1093, 1096, 1098, 1105,
+ 1112, 1139, 1143, 1144, 1152, 1154, 1161, 1167, 1168, 1170, 1183, 1184, 1197, 1217, 1224, 1228,
+ 1272, 1276, 1309, 1323, 1347, 1367, 1377, 1404, 1473, 1475, 1486, 1509, 1537, 1544, 1546, 1553,
+ 1555, 1576, 1589, 1594, 1600, 1602, 1616, 1625, 1636, 1638, 1665, 1667, 1672, 1685, 1706, 1722,
+ 1737, 1755, 1816, 1831, 1850, 1856, 1862, 1874, 1901, 1932, 1950, 1971, 2011, 2032, 2052, 2063,
+ 2077, 2079, 2091, 2095, 2172, 2192, 2207, 2208, 2224, 2230, 2247, 2277, 2308, 2345, 2356, 2389,
+ 2403, 2424, 2501, 2504, 2506, 2520, 2570, 2593, 2616, 2624, 2630, 2646, 2669, 2700, 2714, 2746,
+ 2754, 2795, 2824, 2835, 2839, 2874, 2882, 2905, 2984, 3028, 3042, 3092, 3108, 3110, 3124, 3153,
+ 3185, 3215, 3252, 3288, 3294, 3364, 3397, 3434, 3483, 3523, 3537, 3587, 3589, 3591, 3592, 3610,
+ 3626, 3670, 3680, 3722, 3749, 3754, 3776, 3789, 3803, 3824, 3857, 3873, 3904, 3906, 3924, 3992,
+ };
+ static const uint16_t kgrid_512[512] = {
+ 0, 1, 2, 5, 7, 8, 9, 10, 12, 14, 16, 17, 21, 27, 32, 34,
+ 37, 39, 41, 43, 48, 50, 57, 60, 63, 64, 65, 66, 68, 72, 73, 77,
+ 80, 83, 87, 89, 93, 100, 113, 117, 122, 128, 129, 133, 135, 136, 139, 142,
+ 145, 149, 152, 156, 162, 165, 167, 169, 171, 184, 187, 195, 201, 205, 208, 210,
+ 217, 219, 222, 228, 232, 234, 247, 249, 253, 256, 267, 271, 273, 276, 282, 288,
+ 291, 297, 312, 322, 324, 336, 338, 342, 347, 353, 357, 359, 374, 379, 390, 393,
+ 395, 409, 426, 441, 448, 450, 452, 464, 466, 470, 475, 488, 492, 512, 513, 514,
+ 516, 520, 521, 523, 525, 527, 528, 530, 537, 540, 542, 556, 558, 561, 570, 576,
+ 577, 579, 582, 584, 588, 593, 600, 603, 609, 616, 618, 632, 638, 640, 650, 653,
+ 655, 656, 660, 666, 672, 675, 685, 688, 698, 705, 708, 711, 712, 715, 721, 727,
+ 728, 732, 737, 754, 760, 771, 773, 778, 780, 793, 795, 802, 806, 808, 812, 833,
+ 840, 843, 849, 856, 858, 873, 912, 916, 919, 932, 934, 961, 963, 968, 970, 977,
+ 989, 993, 1010, 1016, 1024, 1025, 1027, 1029, 1031, 1032, 1034, 1036, 1038, 1041, 1043, 1047,
+ 1048, 1050, 1057, 1059, 1061, 1064, 1066, 1079, 1080, 1083, 1085, 1088, 1090, 1096, 1099, 1103,
+ 1106, 1109, 1113, 1116, 1122, 1129, 1153, 1156, 1159, 1169, 1171, 1176, 1183, 1185, 1195, 1199,
+ 1209, 1212, 1216, 1218, 1221, 1225, 1234, 1236, 1241, 1243, 1250, 1256, 1270, 1281, 1287, 1296,
+ 1299, 1306, 1309, 1313, 1338, 1341, 1348, 1353, 1362, 1375, 1376, 1387, 1400, 1408, 1410, 1415,
+ 1425, 1453, 1457, 1477, 1481, 1494, 1496, 1507, 1512, 1538, 1545, 1547, 1549, 1551, 1554, 1561,
+ 1563, 1565, 1570, 1572, 1575, 1577, 1587, 1593, 1601, 1603, 1605, 1612, 1617, 1619, 1632, 1648,
+ 1658, 1662, 1664, 1674, 1680, 1690, 1692, 1704, 1729, 1736, 1740, 1745, 1747, 1751, 1752, 1761,
+ 1763, 1767, 1773, 1787, 1795, 1801, 1806, 1810, 1817, 1834, 1840, 1844, 1857, 1864, 1866, 1877,
+ 1882, 1892, 1902, 1915, 1934, 1953, 1985, 1987, 2000, 2002, 2013, 2048, 2052, 2058, 2064, 2068,
+ 2071, 2074, 2081, 2088, 2104, 2114, 2119, 2121, 2123, 2130, 2136, 2141, 2147, 2153, 2157, 2177,
+ 2179, 2184, 2189, 2193, 2203, 2208, 2223, 2226, 2232, 2244, 2249, 2251, 2256, 2258, 2265, 2269,
+ 2304, 2306, 2324, 2335, 2336, 2361, 2373, 2375, 2385, 2418, 2443, 2460, 2480, 2504, 2509, 2520,
+ 2531, 2537, 2562, 2568, 2572, 2578, 2592, 2596, 2599, 2602, 2614, 2620, 2625, 2627, 2629, 2634,
+ 2641, 2650, 2682, 2688, 2697, 2707, 2712, 2718, 2731, 2754, 2759, 2760, 2775, 2788, 2793, 2805,
+ 2811, 2817, 2820, 2832, 2842, 2854, 2890, 2902, 2921, 2923, 2978, 3010, 3012, 3026, 3081, 3083,
+ 3085, 3097, 3099, 3120, 3136, 3152, 3159, 3188, 3210, 3228, 3234, 3245, 3250, 3256, 3264, 3276,
+ 3281, 3296, 3349, 3363, 3378, 3392, 3395, 3420, 3440, 3461, 3488, 3529, 3531, 3584, 3588, 3591,
+ 3600, 3602, 3614, 3616, 3628, 3634, 3650, 3657, 3668, 3683, 3685, 3713, 3716, 3720, 3726, 3729,
+ 3736, 3753, 3778, 3802, 3805, 3819, 3841, 3845, 3851, 3856, 3880, 3922, 3938, 3970, 3993, 4032,
+ };
+
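+    // Same table construction as in iq2xs_init_impl, but with 4 coordinates of 3 bits each
+    // (kmap_size = 8^4 = 4096), e.g. entry 0x1a -> l = {2,3,0,0} -> coordinates {5,7,1,1}.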
+ const int kmap_size = 4096;
+ const int nwant = grid_size == 256 ? 2 : 3;
+ const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+ uint32_t * kgrid_q3xs;
+ int * kmap_q3xs;
+ uint16_t * kneighbors_q3xs;
+
+ //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+ uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
+ for (int k = 0; k < grid_size; ++k) {
+ int8_t * pos = (int8_t *)(the_grid + k);
+ for (int i = 0; i < 4; ++i) {
+ int l = (kgrid[k] >> 3*i) & 0x7;
+ pos[i] = 2*l + 1;
+ }
+ }
+ kgrid_q3xs = the_grid;
+ iq3_data[gindex].grid = the_grid;
+ kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
+ iq3_data[gindex].map = kmap_q3xs;
+ for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
+ uint32_t aux32;
+ uint8_t * aux8 = (uint8_t *)&aux32;
+ for (int i = 0; i < grid_size; ++i) {
+ aux32 = kgrid_q3xs[i];
+ uint16_t index = 0;
+ for (int k=0; k<4; ++k) {
+ uint16_t q = (aux8[k] - 1)/2;
+ index |= (q << 3*k);
+ }
+ kmap_q3xs[index] = i;
+ }
+ int8_t pos[4];
+ int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+ int num_neighbors = 0, num_not_in_map = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q3xs[i] >= 0) continue;
+ ++num_not_in_map;
+ for (int k = 0; k < 4; ++k) {
+ int l = (i >> 3*k) & 0x7;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+ int n = 0; int d2 = dist2[0];
+ int nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ ++n;
+ }
+ num_neighbors += n;
+ }
+ //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+ kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+ iq3_data[gindex].neighbours = kneighbors_q3xs;
+ int counter = 0;
+ for (int i = 0; i < kmap_size; ++i) {
+ if (kmap_q3xs[i] >= 0) continue;
+ for (int k = 0; k < 4; ++k) {
+ int l = (i >> 3*k) & 0x7;
+ pos[k] = 2*l + 1;
+ }
+ for (int j = 0; j < grid_size; ++j) {
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+ int d2 = 0;
+ for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+ dist2[2*j+0] = d2;
+ dist2[2*j+1] = j;
+ }
+ qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+ kmap_q3xs[i] = -(counter + 1);
+ int d2 = dist2[0];
+ uint16_t * start = &kneighbors_q3xs[counter++];
+ int n = 0, nhave = 1;
+ for (int j = 0; j < grid_size; ++j) {
+ if (dist2[2*j] > d2) {
+ if (nhave == nwant) break;
+ d2 = dist2[2*j];
+ ++nhave;
+ }
+ kneighbors_q3xs[counter++] = dist2[2*j+1];
+ ++n;
+ }
+ *start = n;
+ }
+ free(dist2);
+}
+
+void iq3xs_free_impl(int grid_size) {
+ GGML_ASSERT(grid_size == 256 || grid_size == 512);
+ const int gindex = iq3_data_index(grid_size);
+ if (iq3_data[gindex].grid) {
+ free(iq3_data[gindex].grid); iq3_data[gindex].grid = NULL;
+ free(iq3_data[gindex].map); iq3_data[gindex].map = NULL;
+ free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
+ }
+}
+
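+// 4-coordinate analogue of iq2_find_best_neighbour: pick the neighbouring grid point with the smallest
+// weighted squared error for the given scale and write its 3-bit codes into L.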
+static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
+ const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+ int num_neighbors = neighbours[0];
+ GGML_ASSERT(num_neighbors > 0);
+ float best_d2 = FLT_MAX;
+ int grid_index = -1;
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float d2 = 0;
+ for (int i = 0; i < 4; ++i) {
+ float q = pg[i];
+ float diff = scale*q - xval[i];
+ d2 += weight[i]*diff*diff;
+ }
+ if (d2 < best_d2) {
+ best_d2 = d2; grid_index = neighbours[j];
+ }
+ }
+ GGML_ASSERT(grid_index >= 0);
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
+ for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
+ return grid_index;
+}
+
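+// Quantize one row using the 3-bit grids. The layout is selected by grid_size: 256 writes block_iq3_xxs
+// (8-bit grid indices, with the four 7-bit sign groups and the 4-bit group scale packed into one uint32_t
+// per group of 32), while 512 writes block_iq3_s style data with an extra high bit per index in qh.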
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n,
+ const float * restrict quant_weights) {
+
+ const int gindex = iq3_data_index(grid_size);
+
+ const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
+ const int * kmap_q3xs = iq3_data[gindex].map;
+ const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 8;
+
+ const int64_t nbl = n/QK_K;
+
+ ggml_fp16_t * dh;
+ uint8_t * qs;
+ int block_size;
+ if (grid_size == 256) {
+ block_iq3_xxs * y = vy;
+ dh = &y->d;
+ qs = y->qs;
+ block_size = sizeof(block_iq3_xxs);
+ } else {
+ block_iq3_s * y = vy;
+ dh = &y->d;
+ qs = y->qs;
+ block_size = sizeof(block_iq3_s);
+ }
+ int quant_size = block_size - sizeof(ggml_fp16_t);
+
+ float scales[QK_K/32];
+ float weight[32];
+ float xval[32];
+ int8_t L[32];
+ int8_t Laux[32];
+ float waux[32];
+ bool is_on_grid[8];
+ bool is_on_grid_aux[8];
+ uint8_t block_signs[8];
+ uint8_t q3[3*(QK_K/8)+QK_K/32];
+ uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
+ uint8_t * qh = q3 + 3*(QK_K/8);
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ dh[0] = GGML_FP32_TO_FP16(0.f);
+ memset(q3, 0, 3*QK_K/8+QK_K/32);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ const float * xb = xbl + 32*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*ibl + 32*ib;
+ for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ } else {
+ for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+ }
+ for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 4; ++k) {
+ int nflip = 0;
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+ }
+ }
+ if (nflip%2) {
+ int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+ for (int i = 1; i < 8; ++i) {
+ float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+ if (ax < min) {
+ min = ax; imin = i;
+ }
+ }
+ xval[8*k+imin] = -xval[8*k+imin];
+ s ^= (1 << imin);
+ }
+ block_signs[k] = s & 127;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+ if (max < GROUP_MAX_EPS_IQ3_XXS) {
+ scales[ib] = 0;
+ memset(L, 0, 32);
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ for (int is = -15; is <= 15; ++is) {
+ float id = (2*kMaxQ-1+is*0.2f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 8; ++k) {
+ for (int i = 0; i < 4; ++i) {
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+ Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+ int grid_index = kmap_q3xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+ for (int k = 0; k < 8; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 8; ++k) {
+ if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) {
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 3*i);
+ }
+ int grid_index = kmap_q3xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+ }
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+ for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 32; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+                // This should never happen, but just in case, flip the scale so that it is positive (we use uints to encode the scale)
+                // and correspondingly flip the quant signs.
+ scale = -scale;
+ for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+ }
+ for (int k = 0; k < 8; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+ int grid_index = kmap_q3xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+ printf("\n");
+ GGML_ASSERT(false);
+ }
+ if (grid_size == 256) {
+ q3[8*ib+k] = grid_index;
+ } else {
+ q3[8*ib+k] = grid_index & 255;
+ qh[ib] |= ((grid_index >> 8) << k);
+ }
+
+ }
+ scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ memset(qs, 0, quant_size);
+ dh += block_size/sizeof(ggml_fp16_t);
+ qs += block_size;
+ continue;
+ }
+
+ float d = max_scale/31;
+ dh[0] = GGML_FP32_TO_FP16(d * 1.0125f); // small improvement via this fudge factor
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/32; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ scales_and_signs[ib] |= ((uint32_t)l << 28);
+ }
+ memcpy(qs, q3, quant_size);
+
+ dh += block_size/sizeof(ggml_fp16_t);
+ qs += block_size;
+
+ }
+}
+
+size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq3_xxs);
+ }
+ return nrow * nblock * sizeof(block_iq3_xxs);
+}
+
+void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq3_xxs * restrict y = vy;
+ quantize_row_iq3_xxs_ref(x, y, k);
+}
+
+void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
+}
+
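+// Quantize one row to IQ3_S. Follows the same search as the IQ3_XXS path above, but signs are stored as
+// full bytes (8 bits per group of 8 values) and the scratch buffers are passed in by the caller, which
+// sizes them for the chosen block_size (see quantize_iq3_s below).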
+static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
+ const float * restrict quant_weights,
+ float * scales,
+ float * weight,
+ float * xval,
+ int8_t * L,
+ int8_t * Laux,
+ float * waux,
+ bool * is_on_grid,
+ bool * is_on_grid_aux,
+ uint8_t * block_signs) {
+
+ const int gindex = iq3_data_index(512);
+
+ const uint32_t * kgrid_q3xs = iq3_data[gindex].grid;
+ const int * kmap_q3xs = iq3_data[gindex].map;
+ const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 8;
+
+ const int64_t nbl = n/QK_K;
+
+ block_iq3_s * y = vy;
+
+ const int bs4 = block_size/4;
+ const int bs8 = block_size/8;
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq3_s));
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+ uint8_t * qs = y[ibl].qs;
+ uint8_t * qh = y[ibl].qh;
+ uint8_t * signs = y[ibl].signs;
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
+ const float * xb = xbl + block_size*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ } else {
+ for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+ }
+ for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < bs8; ++k) {
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
+ }
+ }
+ block_signs[k] = s;
+ }
+ float max = xval[0];
+ for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
+ if (!max) {
+ scales[ib] = 0;
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
+ for (int is = -9; is <= 9; ++is) {
+ float id = (2*kMaxQ-1+is*0.2f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < bs4; ++k) {
+ for (int i = 0; i < 4; ++i) {
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+ Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+ int grid_index = kmap_q3xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < block_size; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
+ for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < bs4; ++k) {
+ //if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) {
+ int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 3*i);
+ }
+ int grid_index = kmap_q3xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+ grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+ }
+ const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+ for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < block_size; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+ // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+ // and correspondingly flip quant signs.
+ scale = -scale;
+ for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
+ }
+ for (int k = 0; k < bs4; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+ int grid_index = kmap_q3xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+ printf("\n");
+ GGML_ASSERT(false);
+ }
+ qs[k] = grid_index & 255;
+ qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
+ }
+ qs += bs4;
+ for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
+ signs += bs8;
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ continue;
+ }
+
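+        // encode the per-block scales with 4 bits each (two per byte), relative to the fp16 super-block
+        // scale d = max_scale/31 (the 1.033f below is an empirical correction factor)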
+ float d = max_scale/31;
+ y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/block_size; ib += 2) {
+ int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
+ l1 = MAX(0, MIN(15, l1));
+ int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
+ l2 = MAX(0, MIN(15, l2));
+ y[ibl].scales[ib/2] = l1 | (l2 << 4);
+ }
+
+ }
+}
+
+#define IQ3S_BLOCK_SIZE 32
+size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ float scales[QK_K/IQ3S_BLOCK_SIZE];
+ float weight[IQ3S_BLOCK_SIZE];
+ float xval[IQ3S_BLOCK_SIZE];
+ int8_t L[IQ3S_BLOCK_SIZE];
+ int8_t Laux[IQ3S_BLOCK_SIZE];
+ float waux[IQ3S_BLOCK_SIZE];
+ bool is_on_grid[IQ3S_BLOCK_SIZE/4];
+ bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
+ uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
+ scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq3_s);
+ }
+ return nrow * nblock * sizeof(block_iq3_s);
+}
+
+void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq3_s * restrict y = vy;
+ quantize_row_iq3_s_ref(x, y, k);
+}
+
+void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq3_s(x, y, 1, k, NULL);
+}
+
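+// Minimal usage sketch (illustrative only): the iq3 grid/map tables must be initialized via
+// ggml_quantize_init() before any of the iq3_s entry points above are used. The names src_row,
+// n_per_row and dst_buf below are hypothetical.
+//
+//     ggml_quantize_init(GGML_TYPE_IQ3_S);
+//     const size_t row_bytes = (n_per_row/QK_K) * sizeof(block_iq3_s);
+//     void * dst_buf = malloc(row_bytes);
+//     size_t written = quantize_iq3_s(src_row, dst_buf, /*nrow=*/1, n_per_row, /*quant_weights=*/NULL);
+//     assert(written == row_bytes);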
+
+// =================================== 1.5 bpw ===================================================
+
+static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+ const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
+ int num_neighbors = neighbours[0];
+ GGML_ASSERT(num_neighbors > 0);
+ float best_score = -FLT_MAX;
+ int grid_index = -1;
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = (pg[i] - 3)/2;
+ float w = weight[i];
+ sumqx += w*q*xval[i];
+ sumq2 += w*q*q;
+ }
+ if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+ *scale = sumqx/sumq2; best_score = *scale * sumqx;
+ grid_index = neighbours[j];
+ }
+ }
+ if (grid_index < 0) {
+ for (int i = 0; i < ngrid; ++i) {
+ const int8_t * grid_i = (const int8_t *)(grid + i);
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < 8; ++j) {
+ float w = weight[j];
+ float q = (grid_i[j] - 3)/2;
+ sumqx += w*q*xval[j];
+ sumq2 += w*q*q;
+ }
+ if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+ *scale = sumqx/sumq2; best_score = *scale*sumqx;
+ grid_index = i;
+ }
+ }
+ }
+ if (grid_index < 0) {
+ printf("Oops, did not find grid point\n");
+ printf("Have %d neighbours\n", num_neighbors);
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = (pg[i] - 3)/2;
+ float w = weight[i];
+ sumqx += w*q*xval[i];
+ sumq2 += w*q*q;
+ }
+ printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+ }
+ }
+ GGML_ASSERT(grid_index >= 0);
+ //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ *scale *= 1.05f; // This is a fudge factor. Don't ask me why it improves the result.
+ //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
+ for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+ return grid_index;
+}
+
+static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+ const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) {
+ int num_neighbors = neighbours[0];
+ GGML_ASSERT(num_neighbors > 0);
+ float best_score = FLT_MAX;
+ int grid_index = -1;
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float d2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = xg[(pg[i] - 1)/2];
+ float w = weight[i];
+ float diff = scale*q - xval[i];
+ d2 += w*diff*diff;
+ }
+ if (d2 < best_score) {
+ best_score = d2;
+ grid_index = neighbours[j];
+ }
+ }
+ if (grid_index < 0) {
+ for (int i = 0; i < ngrid; ++i) {
+ const int8_t * grid_i = (const int8_t *)(grid + i);
+ float d2 = 0;
+ for (int j = 0; j < 8; ++j) {
+ float w = weight[j];
+ float q = xg[(grid_i[j] - 1)/2];
+                float diff = scale*q - xval[j];
+ d2 += w*diff*diff;
+ }
+ if (d2 < best_score) {
+ best_score = d2;
+ grid_index = i;
+ }
+ }
+ }
+ if (grid_index < 0) {
+ printf("Oops, did not find grid point\n");
+ printf("Have %d neighbours\n", num_neighbors);
+ for (int j = 1; j <= num_neighbors; ++j) {
+ const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 8; ++i) {
+ float q = xg[(pg[i] - 1)/2];
+ float w = weight[i];
+ sumqx += w*q*xval[i];
+ sumq2 += w*q*q;
+ }
+ printf(" neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+ }
+ }
+ GGML_ASSERT(grid_index >= 0);
+ const int8_t * pg = (const int8_t *)(grid + grid_index);
+ for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+ return grid_index;
+}
+
+static int iq1_sort_helper(const void * left, const void * right) {
+ const float * l = left;
+ const float * r = right;
+ return *l < *r ? -1 : *l > *r ? 1 : 0;
+}
+
+#define IQ1S_BLOCK_SIZE 32
+#define IQ1M_BLOCK_SIZE 16
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+ float * scales,
+ float * weight,
+ float * sumx,
+ float * sumw,
+ float * pairs,
+ int8_t * L,
+ uint16_t * index,
+ int8_t * shifts) {
+
+ const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ block_iq1_s * y = vy;
+
+ const int64_t nbl = n/QK_K;
+
+ const int block_size = IQ1S_BLOCK_SIZE;
+
+ const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA};
+ const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};
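+    // x_p/x_m are the three allowed ternary levels (-1, 0, +1) shifted by +IQ1S_DELTA and -IQ1S_DELTA
+    // respectively; which of the two variants a block ends up using is recorded in shifts[]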
+
+
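+    // pairs[] stores interleaved (value, original index) entries: pairs[2*j] holds the value and the
+    // int aliased into the second float slot (idx[2*j]) holds the index, so qsort() with a stride of
+    // 2*sizeof(float) sorts by value while carrying the original index along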
+ int * idx = (int *)(pairs + 1);
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+ memset(y[ibl].qs, 0, QK_K/8);
+ memset(y[ibl].qh, 0, QK_K/16);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
+ const float * xb = xbl + block_size*ib;
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ float max = fabsf(xb[0]);
+ for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
+ if (max < GROUP_MAX_EPS_IQ1_S) {
+ scales[ib] = 0;
+ memset(L, 1, block_size);
+ continue;
+ }
+            // Here we solve the weighted sum of squared differences (SSD) minimization problem exactly.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
+            // and score for each possible split.
+ for (int j = 0; j < block_size; ++j) {
+ pairs[2*j] = xb[j];
+ idx[2*j] = j;
+ }
+ qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
+ {
+ sumx[0] = sumw[0] = 0;
+ for (int j = 0; j < block_size; ++j) {
+ int i = idx[2*j];
+ sumx[j+1] = sumx[j] + weight[i]*xb[i];
+ sumw[j+1] = sumw[j] + weight[i];
+ }
+ }
+ float best_score = -FLT_MIN, scale = max;
+ int besti1 = -1, besti2 = -1, best_shift = 0;
+ for (int i1 = 0; i1 <= block_size; ++i1) {
+ for (int i2 = i1; i2 <= block_size; ++i2) {
+ float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];
+ float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];
+ if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+ scale = sumqx/sumq2; best_score = scale*sumqx;
+ besti1 = i1; besti2 = i2; best_shift = 1;
+ }
+ sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];
+ sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];
+ if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+ scale = sumqx/sumq2; best_score = scale*sumqx;
+ besti1 = i1; besti2 = i2; best_shift = -1;
+ }
+ }
+ }
+ GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
+ for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
+ for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+ for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
+ if (scale < 0) {
+ for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
+ scale = -scale; best_shift = -best_shift;
+ }
+ bool all_on_grid = true;
+ const float * xx = best_shift == 1 ? x_p : x_m;
+ for (int k = 0; k < block_size/8; ++k) {
+ uint16_t u = 0;
+ for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ all_on_grid = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
+ GGML_ASSERT(grid_index >= 0);
+ }
+ index[k] = grid_index;
+ }
+ if (!all_on_grid) {
+ float sumqx = 0, sumq2 = 0;
+ for (int k = 0; k < block_size/8; ++k) {
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+ for (int j = 0; j < 8; ++j) {
+ float w = weight[8*k + j];
+ float q = xx[(pg[j] - 1)/2];
+ sumqx += w*q*xb[8*k+j];
+ sumq2 += w*q*q;
+ }
+ }
+ if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
+ }
+ uint16_t h = 0;
+ for (int k = 0; k < block_size/8; ++k) {
+ y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;
+ h |= (index[k] >> 8) << 3*k;
+ }
+ y[ibl].qh[ib] = h;
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ shifts[ib] = best_shift;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ continue;
+ }
+
+ float d = max_scale/15;
+ y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(7, l));
+ if (shifts[ib] == -1) l |= 8;
+ y[ibl].qh[ib] |= (l << 12);
+ }
+ }
+}
+
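+// Note: unlike most of the other quantizations in this file, quantize_iq1_s() requires an importance
+// matrix: quant_weights must not be NULL (see the assert at the top of quantize_row_iq1_s_impl()).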
+size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ float scales[QK_K/IQ1S_BLOCK_SIZE];
+ float weight[IQ1S_BLOCK_SIZE];
+ int8_t L[IQ1S_BLOCK_SIZE];
+ float sumx[IQ1S_BLOCK_SIZE+1];
+ float sumw[IQ1S_BLOCK_SIZE+1];
+ float pairs[2*IQ1S_BLOCK_SIZE];
+ uint16_t index[IQ1S_BLOCK_SIZE/8];
+ int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq1_s);
+ }
+ return nrow * nblock * sizeof(block_iq1_s);
+}
+
+static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+ float * scales,
+ float * weight,
+ float * pairs,
+ int8_t * L,
+ uint16_t * index,
+ int8_t * shifts) {
+
+ const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ //GGML_ASSERT(quant_weights && "missing quantization weights");
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ block_iq1_m * y = vy;
+
+ const int64_t nbl = n/QK_K;
+
+ const int block_size = IQ1M_BLOCK_SIZE;
+
+ const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA};
+ const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
+ const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
+
+ int * idx = (int *)(pairs + 1);
+
+ float sumqx[4], sumq2[4];
+
+ iq1m_scale_t s;
+ const float * xx;
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+ memset(y[ibl].qs, 0, QK_K/8);
+ memset(y[ibl].qh, 0, QK_K/16);
+ memset(y[ibl].scales, 0, QK_K/32);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
+ const float * xb = xbl + block_size*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ } else {
+ for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+ }
+ float max = fabsf(xb[0]);
+ for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
+ if (max < GROUP_MAX_EPS_IQ1_M) {
+ scales[ib] = 0;
+ memset(L, 1, block_size);
+ continue;
+ }
+            // Here we solve the weighted sum of squared differences (SSD) minimization problem exactly.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
+            // and score for each possible split.
+ for (int j = 0; j < block_size; ++j) {
+ pairs[2*j] = xb[j];
+ idx[2*j] = j;
+ }
+ qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
+ float best_score = -FLT_MIN, scale = max;
+ int besti1 = -1, besti2 = -1, best_k = -1;
+ // 0: +, +
+ // 1: +, -
+ // 2: -, +
+ // 3: -, -
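+            // (k encodes which level set, x_p or x_m, is used for the first and for the second half of the block)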
+ for (int i1 = 0; i1 <= block_size; ++i1) {
+ for (int i2 = i1; i2 <= block_size; ++i2) {
+ memset(sumqx, 0, 4*sizeof(float));
+ memset(sumq2, 0, 4*sizeof(float));
+ for (int j = 0; j < i1; ++j) {
+ int i = idx[2*j];
+ if (i < block_size/2) {
+ sumqx[0] += weight[i]*x_p[0]*xb[i];
+ sumqx[1] += weight[i]*x_p[0]*xb[i];
+ sumqx[2] += weight[i]*x_m[0]*xb[i];
+ sumqx[3] += weight[i]*x_m[0]*xb[i];
+ sumq2[0] += weight[i]*x_p[0]*x_p[0];
+ sumq2[1] += weight[i]*x_p[0]*x_p[0];
+ sumq2[2] += weight[i]*x_m[0]*x_m[0];
+ sumq2[3] += weight[i]*x_m[0]*x_m[0];
+ } else {
+ sumqx[0] += weight[i]*x_p[0]*xb[i];
+ sumqx[2] += weight[i]*x_p[0]*xb[i];
+ sumqx[1] += weight[i]*x_m[0]*xb[i];
+ sumqx[3] += weight[i]*x_m[0]*xb[i];
+ sumq2[0] += weight[i]*x_p[0]*x_p[0];
+ sumq2[2] += weight[i]*x_p[0]*x_p[0];
+ sumq2[1] += weight[i]*x_m[0]*x_m[0];
+ sumq2[3] += weight[i]*x_m[0]*x_m[0];
+ }
+ }
+ for (int j = i1; j < i2; ++j) {
+ int i = idx[2*j];
+ if (i < block_size/2) {
+ sumqx[0] += weight[i]*x_p[1]*xb[i];
+ sumqx[1] += weight[i]*x_p[1]*xb[i];
+ sumqx[2] += weight[i]*x_m[1]*xb[i];
+ sumqx[3] += weight[i]*x_m[1]*xb[i];
+ sumq2[0] += weight[i]*x_p[1]*x_p[1];
+ sumq2[1] += weight[i]*x_p[1]*x_p[1];
+ sumq2[2] += weight[i]*x_m[1]*x_m[1];
+ sumq2[3] += weight[i]*x_m[1]*x_m[1];
+ } else {
+ sumqx[0] += weight[i]*x_p[1]*xb[i];
+ sumqx[2] += weight[i]*x_p[1]*xb[i];
+ sumqx[1] += weight[i]*x_m[1]*xb[i];
+ sumqx[3] += weight[i]*x_m[1]*xb[i];
+ sumq2[0] += weight[i]*x_p[1]*x_p[1];
+ sumq2[2] += weight[i]*x_p[1]*x_p[1];
+ sumq2[1] += weight[i]*x_m[1]*x_m[1];
+ sumq2[3] += weight[i]*x_m[1]*x_m[1];
+ }
+ }
+ for (int j = i2; j < block_size; ++j) {
+ int i = idx[2*j];
+ if (i < block_size/2) {
+ sumqx[0] += weight[i]*x_p[2]*xb[i];
+ sumqx[1] += weight[i]*x_p[2]*xb[i];
+ sumqx[2] += weight[i]*x_m[2]*xb[i];
+ sumqx[3] += weight[i]*x_m[2]*xb[i];
+ sumq2[0] += weight[i]*x_p[2]*x_p[2];
+ sumq2[1] += weight[i]*x_p[2]*x_p[2];
+ sumq2[2] += weight[i]*x_m[2]*x_m[2];
+ sumq2[3] += weight[i]*x_m[2]*x_m[2];
+ } else {
+ sumqx[0] += weight[i]*x_p[2]*xb[i];
+ sumqx[2] += weight[i]*x_p[2]*xb[i];
+ sumqx[1] += weight[i]*x_m[2]*xb[i];
+ sumqx[3] += weight[i]*x_m[2]*xb[i];
+ sumq2[0] += weight[i]*x_p[2]*x_p[2];
+ sumq2[2] += weight[i]*x_p[2]*x_p[2];
+ sumq2[1] += weight[i]*x_m[2]*x_m[2];
+ sumq2[3] += weight[i]*x_m[2]*x_m[2];
+ }
+ }
+ for (int k = 0; k < 4; ++k) {
+ if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
+ scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
+ besti1 = i1; besti2 = i2; best_k = k;
+ }
+ }
+ }
+ }
+ GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
+ for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
+ for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+ for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
+ if (scale < 0) {
+ for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
+ scale = -scale;
+ best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
+ }
+ bool all_on_grid = true;
+ for (int k = 0; k < block_size/8; ++k) {
+ if (k == 0) xx = best_k < 2 ? x_p : x_m;
+ else xx = best_k%2 == 0 ? x_p : x_m;
+ uint16_t u = 0;
+ for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ all_on_grid = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
+ GGML_ASSERT(grid_index >= 0);
+ }
+ index[k] = grid_index;
+ }
+ if (!all_on_grid) {
+ float sumqx_f = 0, sumq2_f = 0;
+ for (int k = 0; k < block_size/8; ++k) {
+ if (k == 0) xx = best_k < 2 ? x_p : x_m;
+ else xx = best_k%2 == 0 ? x_p : x_m;
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+ for (int j = 0; j < 8; ++j) {
+ float w = weight[8*k + j];
+ float q = xx[(pg[j] - 1)/2];
+ sumqx_f += w*q*xb[8*k+j];
+ sumq2_f += w*q*q;
+ }
+ }
+ if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
+ }
+ y[ibl].qs[2*ib + 0] = index[0] & 255;
+ y[ibl].qs[2*ib + 1] = index[1] & 255;
+ y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ shifts[ib] = best_k;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ continue;
+ }
+
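+        // pack the 16 block scales with 3 bits each into four 16-bit words (4 scales per word); the top
+        // 4 bits of each word receive one nibble of the fp16 super-block scale at the end of this block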
+ uint16_t * sc = (uint16_t *)y[ibl].scales;
+ float d = max_scale/15;
+ float id = 1/d;
+ float sumqx_f = 0, sumq2_f = 0;
+ for (int ib = 0; ib < QK_K/block_size; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib+0]-1));
+ l = MAX(0, MIN(7, l));
+ sc[ib/4] |= (l << 3*(ib%4));
+ y[ibl].qh[ib] |= masks[shifts[ib]];
+ const float * xb = xbl + block_size*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+ for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ } else {
+ for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+ }
+ for (int k = 0; k < block_size/8; ++k) {
+ if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
+ else xx = shifts[ib]%2 == 0 ? x_p : x_m;
+ const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
+ for (int j = 0; j < 8; ++j) {
+ float w = weight[8*k + j];
+ float q = xx[(pg[j] - 1)/2]*(2*l+1);
+ sumqx_f += w*q*xb[8*k+j];
+ sumq2_f += w*q*q;
+ }
+ }
+ }
+ if (sumq2_f > 0) d = sumqx_f/sumq2_f;
+ s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
+ sc[0] |= ((s.u16 & 0x000f) << 12);
+ sc[1] |= ((s.u16 & 0x00f0) << 8);
+ sc[2] |= ((s.u16 & 0x0f00) << 4);
+ sc[3] |= ((s.u16 & 0xf000) << 0);
+ }
+}
+
+size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ float scales[QK_K/IQ1M_BLOCK_SIZE];
+ float weight[IQ1M_BLOCK_SIZE];
+ int8_t L[IQ1M_BLOCK_SIZE];
+ float pairs[2*IQ1M_BLOCK_SIZE];
+ uint16_t index[IQ1M_BLOCK_SIZE/8];
+ int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq1_m);
+ }
+ return nrow * nblock * sizeof(block_iq1_m);
+}
+
+// ============================ 4-bit non-linear quants
+
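+// binary search for the entry of the sorted (ascending) table val[0..n-1] that is closest to x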
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * restrict x,
+ ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
+ float * scales, float * weight, uint8_t * L,
+ const int8_t * values,
+ const float * quant_weights,
+ const int ntry) {
+
+ float sigma2 = 0;
+ for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+ sigma2 *= 2.f/super_block_size;
+
+ memset(q4, 0, super_block_size/2);
+ dh[0] = GGML_FP32_TO_FP16(0.f);
+
+ float max_scale = 0, amax_scale = 0;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ const float * xb = x + ib*block_size;
+ uint8_t * Lb = L + ib*block_size;
+ if (quant_weights) {
+ const float * qw = quant_weights + ib*block_size;
+ for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+ } else {
+ for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+ }
+ float amax = 0, max = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float ax = fabsf(xb[j]);
+ if (ax > amax) {
+ amax = ax; max = xb[j];
+ }
+ }
+ if (amax < GROUP_MAX_EPS) {
+ scales[ib] = 0;
+ continue;
+ }
+ float d = ntry > 0 ? -max/values[0] : max/values[0];
+ float id = 1/d;
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float al = id*xb[j];
+ int l = best_index_int8(16, values, al);
+ Lb[j] = l;
+ float q = values[l];
+ float w = weight[j];
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ }
+ d = sumqx/sumq2;
+ float best = d*sumqx;
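+        // refine: try 2*ntry+1 candidate inverse scales id = (values[0] + itry)/max and keep the one
+        // with the best weighted score sumqx^2/sumq2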
+ for (int itry = -ntry; itry <= ntry; ++itry) {
+ id = (itry + values[0])/max;
+ sumqx = sumq2 = 0;
+ for (int j = 0; j < block_size; ++j) {
+ float al = id*xb[j];
+ int l = best_index_int8(16, values, al);
+ float q = values[l];
+ float w = weight[j];
+ sumqx += w*q*xb[j];
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ d = sumqx/sumq2; best = d * sumqx;
+ }
+ }
+ scales[ib] = d;
+ float abs_d = fabsf(d);
+ if (abs_d > amax_scale) {
+ amax_scale = abs_d; max_scale = d;
+ }
+ }
+
+ if (super_block_size/block_size > 1) {
+ int nb = super_block_size/block_size;
+ memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
+ float d = -max_scale/32;
+ dh[0] = GGML_FP32_TO_FP16(d);
+ float id = d ? 1/d : 0.f;
+ for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+ int l = nearest_int(id*scales[ib]);
+ l = MAX(-32, MIN(31, l));
+ float dl = d * l;
+ float idl = dl ? 1/dl : 0.f;
+ uint8_t * Lb = L + ib*block_size;
+ const float * xb = x + ib*block_size;
+ for (int j = 0; j < block_size; ++j) {
+ Lb[j] = best_index_int8(16, values, idl*xb[j]);
+ }
+ l += 32;
+ uint8_t l_l = l & 0xf;
+ uint8_t l_h = l >> 4;
+ if (ib%2 == 0) scales_l[ib/2] = l_l;
+ else scales_l[ib/2] |= (l_l << 4);
+ scales_h[ib/8] |= (l_h << 2*(ib%8));
+ }
+ } else {
+ dh[0] = GGML_FP32_TO_FP16(scales[0]);
+ if (ntry > 0) {
+ float id = scales[0] ? 1/scales[0] : 0;
+ for (int j = 0; j < super_block_size; ++j) {
+ L[j] = best_index_int8(16, values, id*x[j]);
+ }
+ }
+ }
+
+ for (int i = 0; i < super_block_size/32; ++i) {
+ for (int j = 0; j < 16; ++j) {
+ q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+ }
+ }
+}
+
+size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK4_NL == 0);
+ int64_t nblock = n_per_row/QK4_NL;
+ char * qrow = (char *)dst;
+ uint8_t L[QK4_NL];
+ float weight[QK4_NL];
+ uint16_t unused_h;
+ uint8_t * unused_l = NULL;
+ float scale;
+ for (int64_t row = 0; row < nrow; ++row) {
+ block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+ quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+ &scale, weight, L, kvalues_iq4nl, qw, 7);
+ }
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq4_nl);
+ }
+ return nrow * nblock * sizeof(block_iq4_nl);
+}
+
+void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) {
+ GGML_ASSERT(k%QK4_NL == 0);
+ int64_t nblock = k/QK4_NL;
+ uint8_t L[QK4_NL];
+ float weight[QK4_NL];
+ uint16_t unused_h;
+ uint8_t * unused_l = NULL;
+ float scale;
+ block_iq4_nl * iq4 = (block_iq4_nl *)vy;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+ &scale, weight, L, kvalues_iq4nl, NULL, -1);
+ }
+}
+
+void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
+ assert(k % QK4_NL == 0);
+ quantize_row_iq4_nl(x, y, k);
+}
+
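+// quantize_iq4_xs() reuses quantize_row_iq4_nl_impl(), but with a QK_K super-block and real
+// scales_h/scales_l storage instead of the unused placeholders passed by quantize_iq4_nl() above.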
+size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ uint8_t L[QK_K];
+ float weight[32];
+ float scales[QK_K/32];
+ for (int64_t row = 0; row < nrow; ++row) {
+ block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
+ for (int ibl = 0; ibl < nblock; ++ibl) {
+ const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
+ quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
+ scales, weight, L, kvalues_iq4nl, qw, 7);
+ }
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq4_xs);
+ }
+ return nrow * nblock * sizeof(block_iq4_xs);
+}
+
+void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq4_xs * restrict y = vy;
+ quantize_row_iq4_xs_ref(x, y, k);
+}
+
+void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq4_xs(x, y, 1, k, NULL);
+}
+
+// =============================== 2.5625 bpw
+
+static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+ const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);
+
+ const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
+ const int * kmap_q2xs = iq2_data[gindex].map;
+ const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+ GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+ GGML_ASSERT(n%QK_K == 0);
+
+ const int kMaxQ = 3;
+
+ const int64_t nbl = n/QK_K;
+
+ block_iq2_s * y = vy;
+
+ float scales[QK_K/16];
+ float weight[16];
+ float xval[16];
+ int8_t L[16];
+ int8_t Laux[16];
+ float waux[16];
+ bool is_on_grid[2];
+ bool is_on_grid_aux[2];
+ uint8_t block_signs[2];
+
+ for (int ibl = 0; ibl < nbl; ++ibl) {
+
+ memset(&y[ibl], 0, sizeof(block_iq2_s));
+ y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+ float max_scale = 0;
+
+ const float * xbl = x + QK_K*ibl;
+ float sumx2 = 0;
+ for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+ float sigma2 = 2*sumx2/QK_K;
+
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ const float * xb = xbl + 16*ib;
+ if (quant_weights) {
+ const float * qw = quant_weights + QK_K*ibl + 16*ib;
+ for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+ } else {
+ for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
+ }
+ for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+ for (int k = 0; k < 2; ++k) {
+ uint8_t s = 0;
+ for (int i = 0; i < 8; ++i) {
+ if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+ else {
+ xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
+ }
+ }
+ block_signs[k] = s;
+ }
+ float max = xval[0];
+ for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+ if (max < GROUP_MAX_EPS_IQ2_S) {
+ scales[ib] = 0;
+ continue;
+ }
+ float best = 0;
+ float scale = max/(2*kMaxQ-1);
+ is_on_grid[0] = is_on_grid[1] = true;
+ for (int is = -9; is <= 9; ++is) {
+ float id = (2*kMaxQ-1+is*0.1f)/max;
+ float this_scale = 1/id;
+ for (int k = 0; k < 2; ++k) {
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+ }
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ is_on_grid_aux[k] = true;
+ if (grid_index < 0) {
+ is_on_grid_aux[k] = false;
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*Laux[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+ scale = sumqx/sumq2; best = scale*sumqx;
+ for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+ for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+ }
+ }
+ int n_not_ongrid = 0;
+ for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+ if (n_not_ongrid > 0 && scale > 0) {
+ float id = 1/scale;
+ for (int k = 0; k < 2; ++k) {
+ if (is_on_grid[k]) continue;
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) {
+ int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+ l = MAX(0, MIN(kMaxQ-1, l));
+ u |= (l << 2*i);
+ L[8*k + i] = l;
+ }
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+ grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+ }
+ }
+ float sumqx = 0, sumq2 = 0;
+ for (int i = 0; i < 16; ++i) {
+ float w = weight[i];
+ float q = 2*L[i] + 1;
+ sumqx += w*xval[i]*q;
+ sumq2 += w*q*q;
+ }
+ if (sumq2 > 0) scale = sumqx/sumq2;
+ }
+ if (scale < 0) {
+ scale = -scale;
+ for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k];
+ }
+ for (int k = 0; k < 2; ++k) {
+ uint16_t u = 0;
+ for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+ int grid_index = kmap_q2xs[u];
+ if (grid_index < 0) {
+ printf("Oops: found point %u not on grid:", u);
+ for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+ printf("\n");
+ GGML_ASSERT(false);
+ }
+ const int i8 = 2*ib + k;
+ y[ibl].qs[i8] = grid_index & 255;
+ y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4));
+ y[ibl].qs[QK_K/8 + i8] = block_signs[k];
+ }
+ GGML_ASSERT(scale >= 0);
+ scales[ib] = scale;
+ max_scale = MAX(max_scale, scale);
+ }
+
+ if (!max_scale) {
+ continue;
+ }
+
+ float d = max_scale/31;
+ y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f);
+ float id = 1/d;
+ for (int ib = 0; ib < QK_K/16; ++ib) {
+ int l = nearest_int(0.5f*(id*scales[ib]-1));
+ l = MAX(0, MIN(15, l));
+ if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+ else y[ibl].scales[ib/2] |= (l << 4);
+ }
+ }
+}
+
+size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_ASSERT(n_per_row%QK_K == 0);
+ int64_t nblock = n_per_row/QK_K;
+ char * qrow = (char *)dst;
+ for (int64_t row = 0; row < nrow; ++row) {
+ quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);
+ src += n_per_row;
+ qrow += nblock*sizeof(block_iq2_s);
+ }
+ return nrow * nblock * sizeof(block_iq2_s);
+}
+
+void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
+ assert(k % QK_K == 0);
+ quantize_iq2_s(x, y, 1, k, NULL);
+}
+
+void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
+ assert(k % QK_K == 0);
+ block_iq2_s * restrict y = vy;
+ quantize_row_iq2_s_ref(x, y, k);
+}
+
+static bool validate_float(float f, size_t i) {
+ if (isinf(f)) {
+ fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+ return false;
+ }
+
+ if (isnan(f)) {
+ fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+ return false;
+ }
+
+ return true;
+}
+
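+// IEEE 754 binary16: an exponent field of all ones (0x7c00) means infinity when the mantissa is zero
+// and NaN when it is non-zero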
+static bool isinf_fp16(ggml_fp16_t f) {
+ return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+ return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static bool validate_fp16(ggml_fp16_t f, size_t i) {
+ if (isinf_fp16(f)) {
+ fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+ return false;
+ }
+
+ if (isnan_fp16(f)) {
+ fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+ return false;
+ }
+
+ return true;
+}
+
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+ const type * q = (const type *) (data); \
+ for (size_t i = 0; i < (nb); ++i) { \
+ if (!validate_fp16(q[i].d, i)) { \
+ return false; \
+ } \
+ }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+ const type * q = (const type *) (data); \
+ for (size_t i = 0; i < (nb); ++i) { \
+ if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
+ return false; \
+ } \
+ }
+
+#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
+ const type * q = (const type *) (data); \
+ for (size_t i = 0; i < (nb); ++i) { \
+ for (size_t j = 0; j < (nr); ++j) { \
+ if (!validate_fp16(q[i].d[j], i)) { \
+ return false; \
+ } \
+ } \
+ }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+ if (type < 0 || type >= GGML_TYPE_COUNT) {
+ fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+ return false;
+ }
+
+ if (nbytes % ggml_type_size(type) != 0) {
+ fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
+ return false;
+ }
+
+ const size_t nb = nbytes/ggml_type_size(type);
+
+ switch (type) {
+ case GGML_TYPE_BF16:
+ {
+ int nans = 0;
+ int infs = 0;
+ const unsigned short * f = (const unsigned short *) data;
+ for (size_t i = 0; i < nb; ++i) {
+ nans += (f[i] & 0x7fff) > 0x7f80;
+ infs += (f[i] & 0x7fff) == 0x7f80;
+ }
+ if (nans) {
+ fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
+ return false;
+ }
+ if (infs) {
+ fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
+ return false;
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+ size_t i = 0;
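+                // SIMD fast path: scan several values at a time for an all-ones exponent field; if any
+                // lane matches, re-check that chunk with the scalar validator so the exact index is reported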
+#if defined(__AVX2__)
+ for (; i + 15 < nb; i += 16) {
+ __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+ __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+ __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+ int mask = _mm256_movemask_epi8(cmp);
+ if (mask) {
+ for (size_t j = 0; j < 16; ++j) {
+ if (!validate_fp16(f[i + j], i + j)) {
+ return false;
+ }
+ }
+ GGML_UNREACHABLE();
+ }
+ }
+#elif defined(__ARM_NEON)
+ for (; i + 7 < nb; i += 8) {
+ uint16x8_t v = vld1q_u16(f + i);
+ uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
+ uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
+ if (mask) {
+ for (size_t j = 0; j < 8; ++j) {
+ if (!validate_fp16(f[i + j], i + j)) {
+ return false;
+ }
+ }
+ GGML_UNREACHABLE();
+ }
+ }
+#endif
+ for (; i < nb; ++i) {
+ if (!validate_fp16(f[i], i)) {
+ return false;
+ }
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ const float * f = (const float *) data;
+ size_t i = 0;
+#if defined(__AVX2__)
+ for (; i + 7 < nb; i += 8) {
+ __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+ __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+ __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+ int mask = _mm256_movemask_epi8(cmp);
+ if (mask) {
+ for (size_t j = 0; j < 8; ++j) {
+ if (!validate_float(f[i + j], i + j)) {
+ return false;
+ }
+ }
+ GGML_UNREACHABLE();
+ }
+ }
+#elif defined(__ARM_NEON)
+ for (; i + 3 < nb; i += 4) {
+ uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
+ uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
+ uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
+ uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
+ if (mask) {
+ for (size_t j = 0; j < 4; ++j) {
+ if (!validate_float(f[i + j], i + j)) {
+ return false;
+ }
+ }
+ GGML_UNREACHABLE();
+ }
+ }
+#endif
+ for (; i < nb; ++i) {
+ if (!validate_float(f[i], i)) {
+ return false;
+ }
+ }
+ } break;
+ case GGML_TYPE_F64:
+ {
+ const double * f = (const double *) data;
+ for (size_t i = 0; i < nb; ++i) {
+ if (!validate_float(f[i], i)) {
+ return false;
+ }
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+ } break;
+ case GGML_TYPE_Q8_K:
+ {
+ const block_q8_K * q = (const block_q8_K *) data;
+ for (size_t i = 0; i < nb; ++i) {
+ if (!validate_float(q[i].d, i)) {
+ return false;
+ }
+ }
+ } break;
+ case GGML_TYPE_IQ1_S:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+ } break;
+ case GGML_TYPE_IQ1_M:
+ {
+ const block_iq1_m * q = (const block_iq1_m *) data;
+ for (size_t i = 0; i < nb; ++i) {
+ iq1m_scale_t scale;
+ const uint16_t * sc = (const uint16_t *)q[i].scales;
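+                    // reassemble the fp16 super-block scale from the top nibble of each of the four
+                    // 16-bit scale words (the inverse of the packing done in quantize_row_iq1_m_impl)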
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ if (!validate_fp16(scale.f16, i)) {
+ return false;
+ }
+ }
+ } break;
+ case GGML_TYPE_IQ2_XXS:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+ } break;
+ case GGML_TYPE_IQ2_XS:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+ } break;
+ case GGML_TYPE_IQ2_S:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+ } break;
+ case GGML_TYPE_IQ3_XXS:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+ } break;
+
+ case GGML_TYPE_IQ3_S:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+ } break;
+ case GGML_TYPE_IQ4_XS:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+ } break;
+ case GGML_TYPE_IQ4_NL:
+ {
+ VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+ } break;
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ {
+ VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
+ } break;
+ case GGML_TYPE_Q4_0_8_8:
+ {
+ VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
+ } break;
+
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_I64:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ // nothing to validate
+ break;
+ default:
+ {
+ fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+ return false;
+ }
+ }
+
+ return true;
+}
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
new file mode 100644
index 00000000..91063633
--- /dev/null
+++ b/ggml/src/ggml-quants.h
@@ -0,0 +1,151 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+// Dequantization
+void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq1_bn (const block_iq1_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq1_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+void iq2xs_init_impl(enum ggml_type type);
+void iq2xs_free_impl(enum ggml_type type);
+void iq3xs_init_impl(int grid_size);
+void iq3xs_free_impl(int grid_size);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp
new file mode 100644
index 00000000..b01ad267
--- /dev/null
+++ b/ggml/src/ggml-rpc.cpp
@@ -0,0 +1,1178 @@
+#include "ggml-rpc.h"
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include <cinttypes>
+#include <string>
+#include <vector>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <unordered_set>
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
+# ifndef NOMINMAX
+# define NOMINMAX
+# endif
+# include <windows.h>
+# include <winsock2.h>
+#else
+# include <arpa/inet.h>
+# include <sys/socket.h>
+# include <sys/types.h>
+# include <netinet/in.h>
+# include <netinet/tcp.h>
+# include <netdb.h>
+# include <unistd.h>
+#endif
+#include <string.h>
+
+#define UNUSED GGML_UNUSED
+
+#define GGML_DEBUG 0
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#ifdef _WIN32
+typedef SOCKET sockfd_t;
+using ssize_t = __int64;
+#else
+typedef int sockfd_t;
+#endif
+
+// cross-platform socket
+struct socket_t {
+ sockfd_t fd;
+ socket_t(sockfd_t fd) : fd(fd) {}
+ ~socket_t() {
+ GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
+#ifdef _WIN32
+ closesocket(this->fd);
+#else
+ close(this->fd);
+#endif
+ }
+};
+
+// ggml_tensor is serialized into rpc_tensor
+#pragma pack(push, 1)
+struct rpc_tensor {
+ uint64_t id;
+ uint32_t type;
+ uint64_t buffer;
+ uint32_t ne[GGML_MAX_DIMS];
+ uint32_t nb[GGML_MAX_DIMS];
+ uint32_t op;
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+ int32_t flags;
+ uint64_t src[GGML_MAX_SRC];
+ uint64_t view_src;
+ uint64_t view_offs;
+ uint64_t data;
+ char name[GGML_MAX_NAME];
+
+ char padding[4];
+};
+#pragma pack(pop)
+
+static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
+
+// RPC commands
+enum rpc_cmd {
+ ALLOC_BUFFER = 0,
+ GET_ALIGNMENT,
+ GET_MAX_SIZE,
+ BUFFER_GET_BASE,
+ FREE_BUFFER,
+ BUFFER_CLEAR,
+ SET_TENSOR,
+ GET_TENSOR,
+ COPY_TENSOR,
+ GRAPH_COMPUTE,
+ GET_DEVICE_MEMORY,
+};
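+
+// Every request starts with a single command byte followed by a length-prefixed
+// payload; see send_rpc_cmd() for the client side and rpc_serve_client() for
+// the server-side dispatch.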
+
+// RPC data structures
+
+static ggml_guid_t ggml_backend_rpc_guid() {
+ static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
+ return &guid;
+}
+
+struct ggml_backend_rpc_buffer_type_context {
+ std::string endpoint;
+ std::string name;
+ size_t alignment;
+ size_t max_size;
+};
+
+struct ggml_backend_rpc_context {
+ std::string endpoint;
+ std::string name;
+};
+
+struct ggml_backend_rpc_buffer_context {
+ std::shared_ptr<socket_t> sock;
+ std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
+ uint64_t remote_ptr;
+ std::string name;
+};
+
+// RPC helper functions
+
+static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
+#ifdef _WIN32
+ if (fd == INVALID_SOCKET) {
+ return nullptr;
+ }
+#else
+ if (fd < 0) {
+ return nullptr;
+ }
+#endif
+ return std::make_shared<socket_t>(fd);
+}
+
+static bool set_no_delay(sockfd_t sockfd) {
+ int flag = 1;
+ // set TCP_NODELAY to disable Nagle's algorithm
+ int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
+ return ret == 0;
+}
+
+static bool set_reuse_addr(sockfd_t sockfd) {
+ int flag = 1;
+ int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
+ return ret == 0;
+}
+
+static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
+ struct sockaddr_in addr;
+ auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
+ auto sock_ptr = make_socket(sockfd);
+ if (sock_ptr == nullptr) {
+ return nullptr;
+ }
+ if (!set_no_delay(sockfd)) {
+ fprintf(stderr, "Failed to set TCP_NODELAY\n");
+ return nullptr;
+ }
+ addr.sin_family = AF_INET;
+ addr.sin_port = htons(port);
+ struct hostent * server = gethostbyname(host);
+ if (server == NULL) {
+ fprintf(stderr, "Cannot resolve host '%s'\n", host);
+ return nullptr;
+ }
+ memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
+ if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
+ return nullptr;
+ }
+ return sock_ptr;
+}
+
+static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
+ auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
+ auto client_socket = make_socket(client_socket_fd);
+ if (client_socket == nullptr) {
+ return nullptr;
+ }
+ if (!set_no_delay(client_socket_fd)) {
+ fprintf(stderr, "Failed to set TCP_NODELAY\n");
+ return nullptr;
+ }
+ return client_socket;
+}
+
+static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
+ auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
+ auto sock = make_socket(sockfd);
+ if (sock == nullptr) {
+ return nullptr;
+ }
+ if (!set_reuse_addr(sockfd)) {
+ fprintf(stderr, "Failed to set SO_REUSEADDR\n");
+ return nullptr;
+ }
+ struct sockaddr_in serv_addr;
+ serv_addr.sin_family = AF_INET;
+ serv_addr.sin_addr.s_addr = inet_addr(host);
+ serv_addr.sin_port = htons(port);
+
+ if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
+ return nullptr;
+ }
+ if (listen(sockfd, 1) < 0) {
+ return nullptr;
+ }
+ return sock;
+}
+
+static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
+ size_t bytes_sent = 0;
+ while (bytes_sent < size) {
+ ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
+ if (n < 0) {
+ return false;
+ }
+ bytes_sent += n;
+ }
+ return true;
+}
+
+static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
+ size_t bytes_recv = 0;
+ while (bytes_recv < size) {
+ ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
+ if (n <= 0) {
+ return false;
+ }
+ bytes_recv += n;
+ }
+ return true;
+}
+
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
+ size_t pos = endpoint.find(':');
+ if (pos == std::string::npos) {
+ return false;
+ }
+ host = endpoint.substr(0, pos);
+ port = std::stoi(endpoint.substr(pos + 1));
+ return true;
+}
+
+// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
+// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
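+//
+// Example: an ALLOC_BUFFER request for a 1 MiB buffer is 17 bytes on the wire
+// (1-byte command, 8-byte request_size equal to 8, 8-byte buffer size), and the
+// response is an 8-byte response_size equal to 16 followed by
+// | remote_ptr (8 bytes) | remote_size (8 bytes) |.
+// Integers are sent verbatim in host byte order, so client and server are
+// assumed to share endianness.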
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ uint8_t cmd_byte = cmd;
+ if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
+ return false;
+ }
+ uint64_t input_size = input.size();
+ if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
+ return false;
+ }
+ if (!send_data(sock->fd, input.data(), input.size())) {
+ return false;
+ }
+ uint64_t output_size;
+ if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
+ return false;
+ }
+ if (output_size == 0) {
+ output.clear();
+ return true;
+ }
+ output.resize(output_size);
+ if (!recv_data(sock->fd, output.data(), output_size)) {
+ return false;
+ }
+ return true;
+}
+
+// RPC client-side implementation
+
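+// Client connections are cached per endpoint through weak_ptr: buffer types and
+// backends that talk to the same endpoint share a single socket while any of
+// them is alive, and get_socket() reconnects transparently once it has expired.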
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
+ static bool initialized = false;
+
+ auto it = sockets.find(endpoint);
+ if (it != sockets.end()) {
+ if (auto sock = it->second.lock()) {
+ return sock;
+ }
+ }
+ std::string host;
+ int port;
+ if (!parse_endpoint(endpoint, host, port)) {
+ return nullptr;
+ }
+#ifdef _WIN32
+ if (!initialized) {
+ WSADATA wsaData;
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) {
+ return nullptr;
+ }
+ initialized = true;
+ }
+#else
+ UNUSED(initialized);
+#endif
+ auto sock = socket_connect(host.c_str(), port);
+ if (sock == nullptr) {
+ return nullptr;
+ }
+ GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+ sockets[endpoint] = sock;
+ return sock;
+}
+
+GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ return ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ // input serialization format: | remote_ptr (8 bytes) |
+ std::vector<uint8_t> input(sizeof(uint64_t), 0);
+ uint64_t remote_ptr = ctx->remote_ptr;
+ memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.empty());
+ delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
+ return ctx->base_cache[buffer];
+ }
+ // input serialization format: | remote_ptr (8 bytes) |
+ std::vector<uint8_t> input(sizeof(uint64_t), 0);
+ uint64_t remote_ptr = ctx->remote_ptr;
+ memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == sizeof(uint64_t));
+ // output serialization format: | base_ptr (8 bytes) |
+ uint64_t base_ptr;
+ memcpy(&base_ptr, output.data(), sizeof(base_ptr));
+ void * base = reinterpret_cast<void *>(base_ptr);
+ ctx->base_cache[buffer] = base;
+ return base;
+}
+
+static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
+ rpc_tensor result;
+ result.id = reinterpret_cast<uint64_t>(tensor);
+ result.type = tensor->type;
+ if (tensor->buffer) {
+ ggml_backend_buffer_t buffer = tensor->buffer;
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ result.buffer = ctx->remote_ptr;
+ } else {
+ result.buffer = 0;
+ }
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+ result.ne[i] = tensor->ne[i];
+ result.nb[i] = tensor->nb[i];
+ }
+ result.op = tensor->op;
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+ result.op_params[i] = tensor->op_params[i];
+ }
+ result.flags = tensor->flags;
+ for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
+ result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
+ }
+ result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
+ result.view_offs = tensor->view_offs;
+ result.data = reinterpret_cast<uint64_t>(tensor->data);
+ snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
+ return result;
+}
+
+GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ UNUSED(buffer);
+ if (ggml_is_quantized(tensor->type)) {
+ // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
+ GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+ }
+}
+
+GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+ size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
+ std::vector<uint8_t> input(input_size, 0);
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
+ memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+ memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+ memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+ GGML_ASSERT(status);
+}
+
+GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+ int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
+ std::vector<uint8_t> input(input_size, 0);
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
+ memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+ memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+ memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == size);
+ // output serialization format: | data (size bytes) |
+ memcpy(data, output.data(), size);
+}
+
+GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+ // check if src and dst are on the same server
+ ggml_backend_buffer_t src_buffer = src->buffer;
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+ if (src_ctx->sock != dst_ctx->sock) {
+ return false;
+ }
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ // input serialization format: | rpc_tensor src | rpc_tensor dst |
+ int input_size = 2*sizeof(rpc_tensor);
+ std::vector<uint8_t> input(input_size, 0);
+ rpc_tensor rpc_src = serialize_tensor(src);
+ rpc_tensor rpc_dst = serialize_tensor(dst);
+ memcpy(input.data(), &rpc_src, sizeof(rpc_src));
+ memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+ GGML_ASSERT(status);
+ // output serialization format: | result (1 byte) |
+ GGML_ASSERT(output.size() == 1);
+ return output[0];
+}
+
+GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ // serialization format: | bufptr (8 bytes) | value (1 byte) |
+ int input_size = sizeof(uint64_t) + sizeof(uint8_t);
+ std::vector<uint8_t> input(input_size, 0);
+ memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
+ memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+ GGML_ASSERT(status);
+}
+
+static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
+ /* .get_name = */ ggml_backend_rpc_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_rpc_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_rpc_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ return buft_ctx->name.c_str();
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ // input serialization format: | size (8 bytes) |
+ int input_size = sizeof(uint64_t);
+ std::vector<uint8_t> input(input_size, 0);
+ memcpy(input.data(), &size, sizeof(size));
+ std::vector<uint8_t> output;
+ auto sock = get_socket(buft_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
+ // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
+ uint64_t remote_ptr;
+ memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
+ size_t remote_size;
+ memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
+ if (remote_ptr != 0) {
+ ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
+ ggml_backend_rpc_buffer_interface,
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
+ remote_size);
+ return buffer;
+ } else {
+ return nullptr;
+ }
+}
+
+static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
+ // input serialization format: | 0 bytes |
+ std::vector<uint8_t> input;
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == sizeof(uint64_t));
+ // output serialization format: | alignment (8 bytes) |
+ uint64_t alignment;
+ memcpy(&alignment, output.data(), sizeof(alignment));
+ return alignment;
+}
+
+GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ return buft_ctx->alignment;
+}
+
+static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
+ // input serialization format: | 0 bytes |
+ std::vector<uint8_t> input;
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == sizeof(uint64_t));
+ // output serialization format: | max_size (8 bytes) |
+ uint64_t max_size;
+ memcpy(&max_size, output.data(), sizeof(max_size));
+ return max_size;
+}
+
+GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ return buft_ctx->max_size;
+}
+
+GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ UNUSED(buft);
+ return ggml_nbytes(tensor);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_rpc_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_rpc_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL,
+};
+
+GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+
+ return rpc_ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+ delete rpc_ctx;
+ delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+}
+
+GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
+ UNUSED(backend);
+    // this is a no-op because we don't have any async operations
+}
+
+static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
+ if (tensor == nullptr) {
+ return;
+ }
+ if (visited.find(tensor) != visited.end()) {
+ return;
+ }
+ visited.insert(tensor);
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ add_tensor(tensor->src[i], tensors, visited);
+ }
+ add_tensor(tensor->view_src, tensors, visited);
+ tensors.push_back(serialize_tensor(tensor));
+}
+
+static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+ uint32_t n_nodes = cgraph->n_nodes;
+ std::vector<rpc_tensor> tensors;
+ std::unordered_set<ggml_tensor*> visited;
+ for (uint32_t i = 0; i < n_nodes; i++) {
+ add_tensor(cgraph->nodes[i], tensors, visited);
+ }
+ // serialization format:
+    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+ uint32_t n_tensors = tensors.size();
+ int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+ output.resize(output_size, 0);
+ memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+ for (uint32_t i = 0; i < n_nodes; i++) {
+ memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
+ }
+ uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
+ *out_ntensors = n_tensors;
+ rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+ memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
+}
+
+GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+ std::vector<uint8_t> input;
+ serialize_graph(cgraph, input);
+ std::vector<uint8_t> output;
+ auto sock = get_socket(rpc_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == 1);
+ return (enum ggml_status)output[0];
+}
+
+GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ UNUSED(backend);
+ UNUSED(op);
+ //TODO: call the remote backend and cache the results
+ return true;
+}
+
+GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+ return false;
+ }
+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+ ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
+}
+
+static ggml_backend_i ggml_backend_rpc_interface = {
+ /* .get_name = */ ggml_backend_rpc_name,
+ /* .free = */ ggml_backend_rpc_free,
+ /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_async = */ NULL,
+ /* .synchronize = */ ggml_backend_rpc_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_rpc_graph_compute,
+ /* .supports_op = */ ggml_backend_rpc_supports_op,
+ /* .supports_buft = */ ggml_backend_rpc_supports_buft,
+ /* .offload_op = */ NULL,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ // NOTE: buffer types are allocated and never freed; this is by design
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
+ auto it = buft_map.find(endpoint);
+ if (it != buft_map.end()) {
+ return it->second;
+ }
+ auto sock = get_socket(endpoint);
+ if (sock == nullptr) {
+ return nullptr;
+ }
+ size_t alignment = get_alignment(sock);
+ size_t max_size = get_max_size(sock);
+ ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
+ /* .alignment = */ alignment,
+ /* .max_size = */ max_size
+ };
+
+ ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
+ /* .iface = */ ggml_backend_rpc_buffer_type_interface,
+ /* .context = */ buft_ctx
+ };
+ buft_map[endpoint] = buft;
+ return buft;
+}
+
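+// Typical client-side usage (sketch; the address and port are illustrative):
+//   ggml_backend_t backend = ggml_backend_rpc_init("192.168.0.2:50052");
+//   ggml_backend_buffer_type_t buft = ggml_backend_rpc_buffer_type("192.168.0.2:50052");
+// The endpoint string is parsed as "<host>:<port>" by parse_endpoint().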
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+ ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
+ /* .endpoint = */ endpoint,
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
+ };
+
+ ggml_backend_t backend = new ggml_backend {
+ /* .guid = */ ggml_backend_rpc_guid(),
+ /* .interface = */ ggml_backend_rpc_interface,
+ /* .context = */ ctx
+ };
+ return backend;
+}
+
+GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
+}
+
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+ // input serialization format: | 0 bytes |
+ std::vector<uint8_t> input;
+ std::vector<uint8_t> output;
+ bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+ GGML_ASSERT(status);
+ GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
+ // output serialization format: | free (8 bytes) | total (8 bytes) |
+ uint64_t free_mem;
+ memcpy(&free_mem, output.data(), sizeof(free_mem));
+ uint64_t total_mem;
+ memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
+ *free = free_mem;
+ *total = total_mem;
+}
+
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+ auto sock = get_socket(endpoint);
+ if (sock == nullptr) {
+ *free = 0;
+ *total = 0;
+ return;
+ }
+ get_device_memory(sock, free, total);
+}
+
+// RPC server-side implementation
+
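+// A single rpc_server instance serves exactly one client connection (see
+// rpc_serve_client() below); any buffers the client did not explicitly free are
+// released in the destructor when the connection closes.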
+class rpc_server {
+public:
+ rpc_server(ggml_backend_t backend) : backend(backend) {}
+ ~rpc_server();
+
+ bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+ void get_alignment(std::vector<uint8_t> & output);
+ void get_max_size(std::vector<uint8_t> & output);
+ bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+ bool free_buffer(const std::vector<uint8_t> & input);
+ bool buffer_clear(const std::vector<uint8_t> & input);
+ bool set_tensor(const std::vector<uint8_t> & input);
+ bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+ bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+ bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+
+private:
+ ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
+ ggml_tensor * create_node(uint64_t id,
+ struct ggml_context * ctx,
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
+
+
+ ggml_backend_t backend;
+ std::unordered_set<ggml_backend_buffer_t> buffers;
+};
+
+bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ // input serialization format: | size (8 bytes) |
+ if (input.size() != sizeof(uint64_t)) {
+ return false;
+ }
+ uint64_t size;
+ memcpy(&size, input.data(), sizeof(size));
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
+ uint64_t remote_ptr = 0;
+ uint64_t remote_size = 0;
+ if (buffer != nullptr) {
+ remote_ptr = reinterpret_cast<uint64_t>(buffer);
+ remote_size = buffer->size;
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+ buffers.insert(buffer);
+ } else {
+ GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
+ }
+ // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
+ output.resize(2*sizeof(uint64_t), 0);
+ memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
+ memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
+ return true;
+}
+
+void rpc_server::get_alignment(std::vector<uint8_t> & output) {
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+ size_t alignment = ggml_backend_buft_get_alignment(buft);
+ GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
+ // output serialization format: | alignment (8 bytes) |
+ output.resize(sizeof(uint64_t), 0);
+ memcpy(output.data(), &alignment, sizeof(alignment));
+}
+
+void rpc_server::get_max_size(std::vector<uint8_t> & output) {
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+ size_t max_size = ggml_backend_buft_get_max_size(buft);
+ GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
+ // output serialization format: | max_size (8 bytes) |
+ output.resize(sizeof(uint64_t), 0);
+ memcpy(output.data(), &max_size, sizeof(max_size));
+}
+
+bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ // input serialization format: | remote_ptr (8 bytes) |
+ if (input.size() != sizeof(uint64_t)) {
+ return false;
+ }
+ uint64_t remote_ptr;
+ memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+ if (buffers.find(buffer) == buffers.end()) {
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+ return false;
+ }
+ void * base = ggml_backend_buffer_get_base(buffer);
+ // output serialization format: | base_ptr (8 bytes) |
+ uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
+ output.resize(sizeof(uint64_t), 0);
+ memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+ return true;
+}
+
+bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
+ // input serialization format: | remote_ptr (8 bytes) |
+ if (input.size() != sizeof(uint64_t)) {
+ return false;
+ }
+ uint64_t remote_ptr;
+ memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+ if (buffers.find(buffer) == buffers.end()) {
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+ return false;
+ }
+ ggml_backend_buffer_free(buffer);
+ buffers.erase(buffer);
+ return true;
+}
+
+bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
+ // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
+ if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
+ return false;
+ }
+ uint64_t remote_ptr;
+ memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
+ uint8_t value;
+ memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
+ GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
+ ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+ if (buffers.find(buffer) == buffers.end()) {
+ GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+ return false;
+ }
+ ggml_backend_buffer_clear(buffer, value);
+ return true;
+}
+
+ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+ ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
+ tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+ for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+ result->nb[i] = tensor->nb[i];
+ }
+ result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+ if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+ return nullptr;
+ }
+ result->op = (ggml_op) tensor->op;
+ for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+ result->op_params[i] = tensor->op_params[i];
+ }
+ result->flags = tensor->flags;
+ result->data = reinterpret_cast<void *>(tensor->data);
+ ggml_set_name(result, tensor->name);
+ return result;
+}
+
+
+bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
+ // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+ if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
+ return false;
+ }
+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+ uint64_t offset;
+ memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+ size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
+
+ struct ggml_init_params params {
+ /*.mem_size =*/ ggml_tensor_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ struct ggml_context * ctx = ggml_init(params);
+ ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+ if (tensor == nullptr) {
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+ ggml_free(ctx);
+ return false;
+ }
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+ const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+ ggml_backend_tensor_set(tensor, data, offset, size);
+ ggml_free(ctx);
+ return true;
+}
+
+bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+ if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
+ return false;
+ }
+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+ uint64_t offset;
+ memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+ uint64_t size;
+ memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
+
+ struct ggml_init_params params {
+ /*.mem_size =*/ ggml_tensor_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ struct ggml_context * ctx = ggml_init(params);
+ ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+ if (tensor == nullptr) {
+ GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+ ggml_free(ctx);
+ return false;
+ }
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+ // output serialization format: | data (size bytes) |
+ output.resize(size, 0);
+ ggml_backend_tensor_get(tensor, output.data(), offset, size);
+ ggml_free(ctx);
+ return true;
+}
+
+bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ // serialization format: | rpc_tensor src | rpc_tensor dst |
+ if (input.size() != 2*sizeof(rpc_tensor)) {
+ return false;
+ }
+ const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
+ const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
+
+ struct ggml_init_params params {
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ struct ggml_context * ctx = ggml_init(params);
+ ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
+ ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+ if (src == nullptr || dst == nullptr) {
+ GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
+ ggml_free(ctx);
+ return false;
+ }
+ GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
+ bool result = ggml_backend_buffer_copy_tensor(src, dst);
+ // output serialization format: | result (1 byte) |
+ output.resize(1, 0);
+ output[0] = result;
+ ggml_free(ctx);
+ return true;
+}
+
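+// Rebuild a graph node from its serialized 64-bit id, recursively materializing
+// its sources and view source; id 0 maps to nullptr and already-created nodes
+// are returned from tensor_map to preserve sharing.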
+ggml_tensor * rpc_server::create_node(uint64_t id,
+ struct ggml_context * ctx,
+ const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+ std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
+ if (id == 0) {
+ return nullptr;
+ }
+ if (tensor_map.find(id) != tensor_map.end()) {
+ return tensor_map[id];
+ }
+ const rpc_tensor * tensor = tensor_ptrs.at(id);
+ struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
+ if (result == nullptr) {
+ return nullptr;
+ }
+ tensor_map[id] = result;
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+ }
+ result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+ result->view_offs = tensor->view_offs;
+ return result;
+}
+
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+ // serialization format:
+    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+ if (input.size() < sizeof(uint32_t)) {
+ return false;
+ }
+ uint32_t n_nodes;
+ memcpy(&n_nodes, input.data(), sizeof(n_nodes));
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+ return false;
+ }
+ const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
+ uint32_t n_tensors;
+ memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
+ if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+ return false;
+ }
+ const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
+ GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
+
+    // note: the required context size depends on this particular graph, so it must be recomputed on every call
+    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_size,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ struct ggml_context * ctx = ggml_init(params);
+ struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
+ graph->n_nodes = n_nodes;
+ std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
+ for (uint32_t i = 0; i < n_tensors; i++) {
+ tensor_ptrs[tensors[i].id] = &tensors[i];
+ }
+ std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
+ for (uint32_t i = 0; i < n_nodes; i++) {
+ int64_t id;
+ memcpy(&id, &nodes[i], sizeof(id));
+ graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
+ }
+ ggml_status status = ggml_backend_graph_compute(backend, graph);
+ // output serialization format: | status (1 byte) |
+ output.resize(1, 0);
+ output[0] = status;
+ ggml_free(ctx);
+ return true;
+}
+
+rpc_server::~rpc_server() {
+ for (auto buffer : buffers) {
+ ggml_backend_buffer_free(buffer);
+ }
+}
+
+static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+ rpc_server server(backend);
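+    // process requests until the client disconnects or a request fails:
+    // each iteration reads one framed command and writes back one
+    // length-prefixed response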
+ while (true) {
+ uint8_t cmd;
+ if (!recv_data(sockfd, &cmd, 1)) {
+ break;
+ }
+ std::vector<uint8_t> input;
+ std::vector<uint8_t> output;
+ uint64_t input_size;
+ if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
+ break;
+ }
+ input.resize(input_size);
+ if (!recv_data(sockfd, input.data(), input_size)) {
+ break;
+ }
+ bool ok = true;
+ switch (cmd) {
+ case ALLOC_BUFFER: {
+ ok = server.alloc_buffer(input, output);
+ break;
+ }
+ case GET_ALIGNMENT: {
+ server.get_alignment(output);
+ break;
+ }
+ case GET_MAX_SIZE: {
+ server.get_max_size(output);
+ break;
+ }
+ case BUFFER_GET_BASE: {
+ ok = server.buffer_get_base(input, output);
+ break;
+ }
+ case FREE_BUFFER: {
+ ok = server.free_buffer(input);
+ break;
+ }
+ case BUFFER_CLEAR: {
+ ok = server.buffer_clear(input);
+ break;
+ }
+ case SET_TENSOR: {
+ ok = server.set_tensor(input);
+ break;
+ }
+ case GET_TENSOR: {
+ ok = server.get_tensor(input, output);
+ break;
+ }
+ case COPY_TENSOR: {
+ ok = server.copy_tensor(input, output);
+ break;
+ }
+ case GRAPH_COMPUTE: {
+ ok = server.graph_compute(input, output);
+ break;
+ }
+ case GET_DEVICE_MEMORY: {
+ // output serialization format: | free (8 bytes) | total (8 bytes) |
+ output.resize(2*sizeof(uint64_t), 0);
+ memcpy(output.data(), &free_mem, sizeof(free_mem));
+ memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+ break;
+ }
+ default: {
+ fprintf(stderr, "Unknown command: %d\n", cmd);
+ ok = false;
+ }
+ }
+ if (!ok) {
+ break;
+ }
+ uint64_t output_size = output.size();
+ if (!send_data(sockfd, &output_size, sizeof(output_size))) {
+ break;
+ }
+ if (!send_data(sockfd, output.data(), output_size)) {
+ break;
+ }
+ }
+}
+
+void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+ std::string host;
+ int port;
+ if (!parse_endpoint(endpoint, host, port)) {
+ return;
+ }
+#ifdef _WIN32
+ {
+ WSADATA wsaData;
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) {
+ fprintf(stderr, "WSAStartup failed: %d\n", res);
+ return;
+ }
+ }
+#endif
+ auto server_socket = create_server_socket(host.c_str(), port);
+ if (server_socket == nullptr) {
+ fprintf(stderr, "Failed to create server socket\n");
+ return;
+ }
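+    // clients are served sequentially: each accepted connection is handled to
+    // completion by rpc_serve_client() before the next accept()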
+ while (true) {
+ auto client_socket = socket_accept(server_socket->fd);
+ if (client_socket == nullptr) {
+ fprintf(stderr, "Failed to accept client connection\n");
+ return;
+ }
+ printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+ rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+ printf("Client connection closed\n");
+ }
+#ifdef _WIN32
+ WSACleanup();
+#endif
+}
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
new file mode 100644
index 00000000..36518ff9
--- /dev/null
+++ b/ggml/src/ggml-sycl.cpp
@@ -0,0 +1,5314 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <float.h>
+#include <limits>
+#include <stdint.h>
+#include <stdio.h>
+#include <vector>
+#include <cmath>
+#include <iostream>
+#include <fstream>
+#include <stdlib.h>
+#include <regex>
+
+#include <sycl/sycl.hpp>
+#include <sycl/half_type.hpp>
+
+#include "ggml-sycl.h"
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-sycl/backend.hpp"
+#include "ggml-sycl/presets.hpp"
+
+bool ggml_sycl_loaded(void);
+void ggml_sycl_free_data(struct ggml_tensor * tensor);
+void ggml_sycl_copy_to_device(struct ggml_tensor * tensor);
+void ggml_sycl_set_main_device(int main_device);
+void ggml_sycl_set_mul_mat_q(bool mul_mat_q);
+void ggml_sycl_get_device_description(int device, char * description, size_t description_size);
+bool ggml_backend_is_sycl(ggml_backend_t backend);
+int ggml_backend_sycl_get_device(ggml_backend_t backend);
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
+static inline int get_sycl_env(const char *env_name, int default_val);
+
+
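+// Copy between two SYCL devices by staging the data through a temporary host
+// buffer; both memcpy operations are synchronous (wait() is called on each).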
+void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+ const void *ptr_src, size_t size) {
+ char *host_buf = (char *)malloc(size);
+ q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
+ q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
+ free(host_buf);
+}
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_sycl_op_mul_mat_t)(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const queue_ptr &stream);
+typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream);
+
+static __dpct_inline__ float op_repeat(const float a, const float b) {
+ return b;
+ GGML_UNUSED(a);
+}
+
+static __dpct_inline__ float op_add(const float a, const float b) {
+ return a + b;
+}
+
+static __dpct_inline__ float op_mul(const float a, const float b) {
+ return a * b;
+}
+
+static __dpct_inline__ float op_div(const float a, const float b) {
+ return a / b;
+}
+
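+// Broadcast binary op: src1 indices are wrapped with `% ne1x` so a smaller src1
+// is repeated along the corresponding dimensions of src0/dst, and each work-item
+// strides over dimension 0 to cover rows wider than the launch grid.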
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1));
+ const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0)) /
+ ne3;
+ const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0)) %
+ ne3;
+
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ for (int i0 = i0s; i0 < ne0;
+ i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ const int i3 = i/(ne2*ne1*ne0);
+ const int i2 = (i/(ne1*ne0)) % ne2;
+ const int i1 = (i/ne0) % ne1;
+ const int i0 = i % ne0;
+
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+static void acc_f32(const float * x, const float * y, float * dst, const int ne,
+ const int ne10, const int ne11, const int ne12,
+ const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ if (i >= ne) {
+ return;
+ }
+ int src1_idx = i - offset;
+ int oz = src1_idx / nb2;
+ int oy = (src1_idx - (oz * nb2)) / nb1;
+ int ox = src1_idx % nb1;
+ if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+ dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+ } else {
+ dst[i] = x[i];
+ }
+}
+
+static void gelu_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const float GELU_COEF_A = 0.044715f;
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+
+ float xi = x[i];
+ dst[i] = 0.5f * xi *
+ (1.0f +
+ sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
+}
+
+static void silu_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
+}
+
+static void gelu_quick_f32(const float *x, float *dst, int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const float GELU_QUICK_COEF = -1.702f;
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
+}
+
+static void tanh_f32(const float *x, float *dst, int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sycl::tanh((float)(x[i]));
+}
+
+static void relu_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sycl::fmax((float)(x[i]), (float)0);
+}
+
+static void hardsigmoid_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static void hardswish_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sycl::fmax((float)(x[i]), (float)0) +
+ sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
+}
+
+static void sqr_f32(const float * x, float * dst, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * x[i];
+}
+
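+// Nearest-neighbour upscale: each destination coordinate is mapped back to a
+// source element by dividing the per-dimension indices by the scale factors
+// sf0..sf3 and applying the source byte strides nb00..nb03.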
+static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+ int index = item_ct1.get_local_id(0) +
+ item_ct1.get_group(0) * item_ct1.get_local_range(0);
+ if (index >= ne10 * ne11 * ne12 * ne13) {
+ return;
+ }
+ // operation
+ int i10 = index % ne10;
+ int i11 = (index / ne10) % ne11;
+ int i12 = (index / (ne10 * ne11)) % ne12;
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+ int i00 = i10 / sf0;
+ int i01 = i11 / sf1;
+ int i02 = i12 / sf2;
+ int i03 = i13 / sf3;
+
+ dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+}
+
+static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
+ item_ct1.get_group(0) < ne02) {
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+ item_ct1.get_group(0) * ne00 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ dst[offset_dst] = 0.0f;
+ }
+}
+
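+// Quantize rows of kx floats (padded to kx_padded) to Q8_1: each block of QK8_1
+// values stores int8 quants plus, in half precision, the scale d = amax/127 and
+// the reduced sum of the block's values; only the work-item handling the first
+// quant of a block writes the (d, sum) pair.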
+template<int QUANT_BLOCK_TILE>
+static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
+ const sycl::nd_item<3> &item_ct1) {
+ const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
+
+ if (ix >= kx_padded) {
+ return;
+ }
+
+ const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+
+ const int i_padded = iy*kx_padded + ix;
+
+ block_q8_1 * y = (block_q8_1 *) vy;
+
+ const int ib = i_padded / QK8_1; // block index
+ const int iqs = i_padded % QK8_1; // quant index
+ typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
+ typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
+ TC zeros;
+ TQ qzeros;
+#pragma unroll
+ for (int i = 0; i < QUANT_BLOCK_TILE; i++)
+ {
+ zeros[i] = 0.f;
+ qzeros[i] = 0;
+ }
+ const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
+ float sum = xi[0];
+ float amax = sycl::fabs(xi[0]);
+#pragma unroll
+ for (int i = 1; i < QUANT_BLOCK_TILE; i++)
+ {
+ sum += xi[i];
+ amax = sycl::fmax(sycl::fabs(xi[i]), amax);
+ }
+ sum = warp_reduce_sum(sum, item_ct1);
+ amax = warp_reduce_max(amax, item_ct1);
+
+ const float d = amax / 127;
+ TQ q = qzeros;
+ if (amax != 0.0f)
+ {
+#pragma unroll
+ for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
+ q[i] = sycl::round(xi[i] / d);
+ }
+ }
+
+ *(TQ *)&y[ib].qs[iqs] = q;
+
+ if (iqs > 0) {
+ return;
+ }
+
+ reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
+ reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
+}
+
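+// Gather rows selected by src1 from a quantized src0: each work-item dequantizes
+// one quant pair (two output values) via dequantize_kernel and writes them to
+// the destination row.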
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12,
+ const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+ const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+ item_ct1.get_local_id(2)) *
+ 2;
+ const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+ const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+ item_ct1.get_local_id(0)) /
+ ne12;
+ const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+ item_ct1.get_local_id(0)) %
+ ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ dfloat2 v;
+ dequantize_kernel(src0_row, ib, iqs, v);
+
+ dst_row[iybs + iqs + 0] = v.x();
+ dst_row[iybs + iqs + y_offset] = v.y();
+}
+
+template<typename src0_t, typename dst_t>
+static void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12,
+ const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+
+ const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+ item_ct1.get_local_id(2);
+ const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+ const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+ item_ct1.get_local_id(0)) /
+ ne12;
+ const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+ item_ct1.get_local_id(0)) %
+ ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = src0_row[i00];
+}
+
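+// src0 (f16) is stored transposed/permuted (see the ix computation below): each
+// sub-group accumulates one dst element and reduces the partial sums with
+// shuffle-xor before lane 0 writes the result.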
+static void mul_mat_p021_f16_f32(
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const sycl::half *x = (const sycl::half *)vx;
+
+ const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+ const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0);
+ const int channel_x = channel / (nchannels_y / nchannels_x);
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x;
+ col_x0 += item_ct1.get_local_range(2)) {
+ const int col_x = col_x0 + item_ct1.get_local_id(2);
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ // x is transposed and permuted
+ const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+ const float xi =
+ sycl::vec<sycl::half, 1>(x[ix])
+ .convert<float, sycl::rounding_mode::automatic>()[0];
+
+ const int row_y = col_x;
+
+
+ // y is not transposed but permuted
+ const int iy = channel*nrows_y + row_y;
+
+ tmp += xi * y[iy];
+ }
+
+ // dst is not transposed and not permuted
+ const int idst = channel*nrows_dst + row_dst;
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[idst] = tmp;
+ }
+}
+
+static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+ const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const sycl::half *x = (const sycl::half *)vx;
+
+ const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+ const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0);
+ const int channel_x = channel / channel_x_divisor;
+
+ const int nrows_y = ncols_x;
+ const int nrows_dst = nrows_x;
+ const int row_dst = row_x;
+
+ const int idst = channel*nrows_dst + row_dst;
+
+ float tmp = 0.0f;
+
+ for (int col_x0 = 0; col_x0 < ncols_x;
+ col_x0 += item_ct1.get_local_range(2)) {
+ const int col_x = col_x0 + item_ct1.get_local_id(2);
+
+ if (col_x >= ncols_x) {
+ break;
+ }
+
+ const int row_y = col_x;
+
+ const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+ const int iy = channel*nrows_y + row_y;
+
+ const float xi =
+ sycl::vec<sycl::half, 1>(x[ix])
+ .convert<float, sycl::rounding_mode::automatic>()[0];
+
+ tmp += xi * y[iy];
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[idst] = tmp;
+ }
+}
+
+static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ float * dsti = (float *) cdsti;
+
+ *dsti = *xi;
+}
+
+static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ sycl::half *dsti = (sycl::half *)cdsti;
+
+ *dsti = sycl::vec<float, 1>(*xi)
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+}
+
+static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+ const sycl::half *xi = (const sycl::half *)cxi;
+ sycl::half *dsti = (sycl::half *)cdsti;
+
+ *dsti = *xi;
+}
+
+static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+ const sycl::half *xi = (const sycl::half *)cxi;
+ float * dsti = (float *) cdsti;
+
+ *dsti = *xi;
+}
+
+static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
+ const int16_t *xi = (const int16_t *)cxi;
+ int16_t *dsti = (int16_t *)cdsti;
+
+ *dsti = *xi;
+}
+
+static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+ const int32_t *xi = (const int32_t *)cxi;
+ int32_t *dsti = (int32_t *)cdsti;
+
+ *dsti = *xi;
+}
+
+template <cpy_kernel_t cpy_1>
+static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= ne) {
+ return;
+ }
+
+ // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+ // then combine those indices with the corresponding byte offsets to get the total offsets
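+    // e.g. for ne00 = 4, ne01 = 2, ne02 = 1 the flattened index i = 5 decomposes
+    // into i03 = 0, i02 = 0, i01 = 1, i00 = 1; the per-dimension byte strides then
+    // allow source and destination to use different (possibly non-contiguous) layouts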
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q8_0 * dsti = (block_q8_0 *) cdsti;
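+    // q8_0: one scale d per block of QK8_0 values; each value is stored as
+    // round(x/d) in an int8_t, with d = amax/127 so the largest magnitude maps to +-127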
+
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = xi[j];
+ amax = sycl::fmax(amax, sycl::fabs((float)v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = xi[j]*id;
+
+ dsti->qs[j] = sycl::round((float)x0);
+ }
+}
+
+static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_0 * dsti = (block_q4_0 *) cdsti;
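+    // q4_0: symmetric 4-bit quantization with an implicit offset of 8; choosing
+    // d = vmax / -8 puts the largest-magnitude value exactly on one end of the
+    // [-8, 7] range, so it is reconstructed exactly as (q - 8)*d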
+
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK4_0; ++j) {
+ const float v = xi[j];
+ if (amax < sycl::fabs((float)v)) {
+ amax = sycl::fabs((float)v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = xi[0 + j]*id;
+ const float x1 = xi[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
+
+ dsti->qs[j] = xi0;
+ dsti->qs[j] |= xi1 << 4;
+ }
+}
+
+static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+ const float * xi = (const float *) cxi;
+ block_q4_1 * dsti = (block_q4_1 *) cdsti;
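+    // q4_1: affine 4-bit quantization x ~= q*d + vmin with q in [0, 15];
+    // dm packs the scale d and the block minimum vmin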
+
+ float vmin = FLT_MAX;
+ float vmax = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; ++j) {
+ const float v = xi[j];
+
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dsti->dm.x() = d;
+ dsti->dm.y() = vmin;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (xi[0 + j] - vmin)*id;
+ const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+ const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
+
+ dsti->qs[j] = xi0;
+ dsti->qs[j] |= xi1 << 4;
+ }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static void cpy_f32_q(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+ const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2)) *
+ qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int i03 = i/(ne00 * ne01 * ne02);
+ const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int i13 = i/(ne10 * ne11 * ne12);
+ const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(1);
+ const int col = item_ct1.get_local_id(2);
+
+ float sum = 0.0f;
+ for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
+ sum += x[row * ncols + i];
+ }
+
+ sum = warp_reduce_sum(sum, item_ct1);
+
+ if (col == 0) {
+ dst[row] = sum;
+ }
+}
+
+
+template<typename T>
+static inline void ggml_sycl_swap(T & a, T & b) {
+ T tmp = a;
+ a = b;
+ b = tmp;
+}
+
+template <ggml_sort_order order>
+__dpct_inline__ static void
+k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
+ const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
+ // bitonic sort
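+    // (sorting network over ncols_pad = next power of two entries kept in local
+    //  memory; index values >= ncols act as padding so that only the first ncols
+    //  results need to be copied back at the end)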
+ int col = item_ct1.get_local_id(2);
+ int row = item_ct1.get_group(1);
+
+ if (col >= ncols_pad) {
+ return;
+ }
+
+ const float * x_row = x + row * ncols;
+ auto dst_row = (int *)dpct_local;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ /*
+ DPCT1118:1: SYCL group functions and algorithms must be encountered
+ in converged control flow. You may need to adjust the code.
+ */
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+
+static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
+ const sycl::nd_item<3> &item_ct1) {
+ const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1);
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int i = row*ncols + col;
+ //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+ //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void scale_f32(const float * x, float * dst, const float scale, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = scale * x[i];
+}
+
+static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
+
+template <typename T>
+static void im2col_kernel(const float *x, T *dst, int offset_delta,
+ int IW, int IH, int OW, int KW, int KH,
+ int pelements, int CHW, int s0, int s1, int p0,
+ int p1, int d0, int d1,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (i >= pelements) {
+ return;
+ }
+
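+    // i enumerates the kernel taps and output columns for the output row given by
+    // group(1) and the input channel given by group(0); taps that fall outside the
+    // padded input are written as 0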
+ const int ksize = OW * (KH > 1 ? KW : 1);
+ const int kx = i / ksize;
+ const int kd = kx * ksize;
+ const int ky = (i - kd) / OW;
+ const int ix = i % OW;
+
+ const int64_t iiw = ix * s0 + kx * d0 - p0;
+ const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
+
+ const int64_t offset_dst =
+ (item_ct1.get_group(1) * OW + ix) * CHW +
+ (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] =
+ sycl::vec<float, 1>(0.0f)
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+ } else {
+ const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
+ dst[offset_dst] =
+ sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+ }
+}
+
+template <typename Ti, typename To>
+static void pool2d_nchw_kernel(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const Ti* src, To* dst, const enum ggml_op_pool op,
+ const sycl::nd_item<3> &item_ct1) {
+ int idx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (idx >= parallel_elements) {
+ return;
+ }
+
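+    // one work-item per output element across all channels; for AVG the divisor is
+    // the full kernel area kh*kw even when the window is clipped at the borders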
+ const int I_HW = ih * iw;
+ const int O_HW = oh * ow;
+ const int nc = idx / O_HW;
+ const int cur_oh = idx % O_HW / ow;
+ const int cur_ow = idx % O_HW % ow;
+ const Ti* i_ptr = src + nc * I_HW;
+ To* o_ptr = dst + nc * O_HW;
+ const int start_h = cur_oh * sh - ph;
+ const int bh = sycl::max(0, start_h);
+ const int eh = sycl::min(ih, start_h + kh);
+ const int start_w = cur_ow * sw - pw;
+ const int bw = sycl::max(0, start_w);
+ const int ew = sycl::min(iw, start_w + kw);
+
+ To res = 0;
+
+ switch (op) {
+ case GGML_OP_POOL_AVG: res = 0; break;
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+ }
+
+ for (int i = bh; i < eh; i += 1) {
+ for (int j = bw; j < ew; j += 1) {
+#if DPCT_COMPATIBILITY_TEMP >= 350
+ /*
+ DPCT1098:106: The '*' expression is used instead of the __ldg
+ call. These two expressions do not provide the exact same
+ functionality. Check the generated code for potential precision
+ and/or performance issues.
+ */
+ Ti cur = *(i_ptr + i * iw + j);
+#else
+ Ti cur = i_ptr[i * iw + j];
+#endif
+ switch (op) {
+ case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
+ case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
+ }
+ }
+ }
+ o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+template <int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const void *src0_dd,
+ const int32_t *src1_dd, float *dst_dd,
+ queue_ptr stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+ const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+ const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ GGML_ASSERT(ne00 % 2 == 0);
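+    // each work-item dequantizes two values (v.x/v.y in k_get_rows), hence half as
+    // many work-items along ne00 and the requirement that ne00 is even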
+
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_get_rows<qk, qr, dq>(
+ src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+ s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+ });
+
+ (void) dst;
+}
+
+template <typename src0_t>
+static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const src0_t *src0_dd, const int32_t *src1_dd,
+ float *dst_dd, queue_ptr stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+ const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
+ const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+ s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+ });
+ }
+
+ (void) dst;
+}
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+ template <typename src0_t, typename src1_t, typename dst_t>
+ void operator()(ggml_backend_sycl_context & ctx,
+ const struct ggml_tensor *src0,
+ const struct ggml_tensor *src1, struct ggml_tensor *dst,
+ const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+ queue_ptr stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10/ne0;
+ int nr1 = ne11/ne1;
+ int nr2 = ne12/ne2;
+ int nr3 = ne13/ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ // collapse dimensions until first broadcast dimension
+ int64_t cne0[] = {ne0, ne1, ne2, ne3};
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
+ size_t cnb0[] = {nb0, nb1, nb2, nb3};
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ {
+ int64_t ne0 = cne0[0];
+ int64_t ne1 = cne0[1];
+ int64_t ne2 = cne0[2];
+ int64_t ne3 = cne0[3];
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb0[0];
+ size_t nb1 = cnb0[1];
+ size_t nb2 = cnb0[2];
+ size_t nb3 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+ sycl::range<3> block_dims(1, 1, 1);
+ block_dims[2] = std::min<unsigned int>(hne0, block_size);
+ block_dims[1] = std::min<unsigned int>(
+ ne1, block_size / (unsigned int)block_dims[2]);
+ block_dims[0] = std::min(
+ std::min<unsigned int>(
+ ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+ (unsigned int)block_dims[1]),
+ 64U);
+
+ sycl::range<3> block_nums(
+ (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+ (ne1 + block_dims[1] - 1) / block_dims[1],
+ (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+ if (block_nums[0] > 65535) {
+ // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+ sycl::range<3>(1, 1, block_size),
+ sycl::range<3>(1, 1, block_size)),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_bin_bcast_unravel<bin_op>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+ s13, item_ct1);
+ });
+ }
+ } else {
+ /*
+ DPCT1049:16: The work-group size passed to the SYCL kernel may
+ exceed the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if
+ needed.
+ */
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+ ne2, ne3, ne10, ne11, ne12, ne13,
+ s1, s2, s3, s11, s12, s13,
+ item_ct1);
+ });
+ }
+ }
+ }
+};
+
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
+ const int n_elements, const int ne10, const int ne11,
+ const int ne12, const int nb1, const int nb2,
+ const int offset, queue_ptr stream) {
+ int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+ item_ct1);
+ });
+}
+
+static void gelu_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ gelu_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void silu_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ silu_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ gelu_quick_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void tanh_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ tanh_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void relu_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ relu_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ hardsigmoid_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void hardswish_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ hardswish_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+ const float negative_slope,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
+ });
+}
+
+static void sqr_f32_sycl(const float *x, float *dst, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ sqr_f32(x, dst, k, item_ct1);
+ });
+}
+
+static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+ const int nb02, const int nb03, const int ne10, const int ne11,
+ const int ne12, const int ne13, const float sf0, const float sf1,
+ const float sf2, const float sf3, queue_ptr stream) {
+ int dst_size = ne10 * ne11 * ne12 * ne13;
+ int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+ sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+ [=](sycl::nd_item<1> item_ct1) {
+ upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
+ });
+}
+
+static void pad_f32_sycl(const float *x, float *dst, const int ne00,
+ const int ne01, const int ne02, const int ne0,
+ const int ne1, const int ne2, queue_ptr stream) {
+ int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+ sycl::range<3> gridDim(ne2, ne1, num_blocks);
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
+ });
+}
+
+static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
+ const int ky, const int kx_padded,
+ queue_ptr stream) {
+ const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+ const sycl::range<3> num_blocks(1, ky, block_num_x);
+ int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+ static_assert(QK8_1 % WARP_SIZE == 0);
+ const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(num_blocks * block_size, block_size),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
+ });
+ }
+}
+
+static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
+ float *dst, const int ncols_x,
+ const int nrows_x,
+ const int nchannels_x,
+ const int nchannels_y,
+ queue_ptr stream) {
+
+ const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
+ nchannels_y, item_ct1);
+ });
+ }
+}
+
+static void ggml_mul_mat_vec_nc_f16_f32_sycl(
+ const void *vx, const float *y, float *dst, const int ncols_x,
+ const int nrows_x, const int row_stride_x, const int nchannels_x,
+ const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+
+ const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
+ row_stride_x, channel_stride_x,
+ nchannels_y / nchannels_x, item_ct1);
+ });
+ }
+}
+
+static void
+ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
+ const int ne01, const int ne02, const int nb00,
+ const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12,
+ const int nb10, const int nb11, const int nb12,
+ const int nb13, queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
+ nb01, nb02, nb03, ne10, ne11, ne12,
+ nb10, nb11, nb12, nb13, item_ct1);
+ });
+ }
+}
+
+static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+ }
+}
+
+static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+ }
+}
+
+static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ GGML_ASSERT(ne % QK8_0 == 0);
+ const int num_blocks = ne / QK8_0;
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+ sycl::range<3>(1, 1, 1)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+}
+
+static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ GGML_ASSERT(ne % QK4_0 == 0);
+ const int num_blocks = ne / QK4_0;
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+ sycl::range<3>(1, 1, 1)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+}
+
+static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ GGML_ASSERT(ne % QK4_1 == 0);
+ const int num_blocks = ne / QK4_1;
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+ sycl::range<3>(1, 1, 1)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+}
+
+static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+ }
+}
+
+static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ // dpct::has_capability_or_fail(stream->get_device(),
+ // {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+ }
+}
+
+static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
+ const int ne00, const int ne01,
+ const int ne02, const int nb00,
+ const int nb01, const int nb02,
+ const int nb03, const int ne10,
+ const int ne11, const int ne12,
+ const int nb10, const int nb11,
+ const int nb12, const int nb13,
+ queue_ptr stream) {
+
+ const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+ {
+ // dpct::has_capability_or_fail(stream->get_device(),
+ // {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+ nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+ item_ct1);
+ });
+ }
+}
+
+static void scale_f32_sycl(const float *x, float *dst, const float scale,
+ const int k, queue_ptr stream) {
+ const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ scale_f32(x, dst, scale, k, item_ct1);
+ });
+}
+
+static void clamp_f32_sycl(const float *x, float *dst, const float min,
+ const float max, const int k,
+ queue_ptr stream) {
+ const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ clamp_f32(x, dst, min, max, k, item_ct1);
+ });
+}
+
+static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
+ const int nrows, queue_ptr stream) {
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ const sycl::range<3> block_nums(1, nrows, 1);
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ k_sum_rows_f32(x, dst, ncols, item_ct1);
+ });
+}
+
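+// round x up to the next power of two, e.g. next_power_of_2(3) == 4, next_power_of_2(8) == 8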
+static int next_power_of_2(int x) {
+ int n = 1;
+ while (n < x) {
+ n *= 2;
+ }
+ return n;
+}
+
+static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
+ const int nrows, ggml_sort_order order,
+ queue_ptr stream) {
+ // bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ const sycl::range<3> block_dims(1, 1, ncols_pad);
+ const sycl::range<3> block_nums(1, nrows, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+ sycl::range<1>(shared_mem), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
+ x, dst, ncols, ncols_pad, item_ct1,
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+ .get());
+ });
+ });
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+ sycl::range<1>(shared_mem), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
+ x, dst, ncols, ncols_pad, item_ct1,
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+ .get());
+ });
+ });
+ } else {
+ GGML_ASSERT(false);
+ }
+}
+
+static void diag_mask_inf_f32_sycl(const float *x, float *dst,
+ const int ncols_x, const int nrows_x,
+ const int rows_per_channel, const int n_past,
+ queue_ptr stream) {
+ const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
+ const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
+ const sycl::range<3> block_nums(1, block_num_x, nrows_x);
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ diag_mask_inf_f32(x, dst, ncols_x,
+ rows_per_channel, n_past,
+ item_ct1);
+ });
+}
+
+template <typename T>
+static void im2col_sycl(const float *x, T *dst, int IW, int IH,
+ int OW, int OH, int KW, int KH, int IC,
+ int offset_delta, int s0, int s1, int p0,
+ int p1, int d0, int d1,
+ queue_ptr stream) {
+ const int parallel_elements = OW * KW * KH;
+ const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
+ sycl::range<3> block_nums(IC, OH, num_blocks);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums *
+ sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
+ parallel_elements, (IC * KH * KW), s0, s1, p0,
+ p1, d0, d1, item_ct1);
+ });
+ }
+}
+
+
+static bool g_sycl_loaded = false;
+
+bool ggml_sycl_loaded(void) {
+ return g_sycl_loaded;
+}
+
+void print_device_detail(int id, sycl::device &device, std::string device_type) {
+
+ dpct::device_info prop;
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::get_device_info(prop, device)));
+
+ std::string version;
+ version += std::to_string(prop.get_major_version());
+ version += ".";
+ version += std::to_string(prop.get_minor_version());
+
+ device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
+ std::string name = std::string(prop.get_name());
+ name = std::regex_replace(name, std::regex("\\(R\\)"), "");
+ name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
+
+ auto global_mem_size = prop.get_global_mem_size()/1000000;
+
+ fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+ name.c_str(), version.c_str(), prop.get_max_compute_units(),
+ prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
+ global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
+}
+
+void ggml_backend_sycl_print_sycl_devices() {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
+ int device_count = dpct::dev_mgr::instance().device_count();
+ std::map<std::string, size_t> DeviceNums;
+ fprintf(stderr, "found %d SYCL devices:\n", device_count);
+ fprintf(stderr, "| | | | |Max | |Max |Global | |\n");
+ fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n");
+ fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n");
+ fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
+ for (int id = 0; id < device_count; ++id) {
+ sycl::device device = dpct::dev_mgr::instance().get_device(id);
+ sycl::backend backend = device.get_backend();
+ std::string backend_type = get_device_backend_and_type(device);
+ int type_id=DeviceNums[backend_type]++;
+ std::stringstream device_type;
+ device_type << "[" << backend_type << ":" << std::to_string(type_id) << "]";
+ print_device_detail(id, device, device_type.str());
+ }
+}
+
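+// read an integer setting from the environment, e.g. get_sycl_env("GGML_SYCL_DEBUG", 0)
+// returns 1 when GGML_SYCL_DEBUG=1 is set and default_val when the variable is unset
+// or not a number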
+static inline int get_sycl_env(const char *env_name, int default_val) {
+ char *user_device_string = getenv(env_name);
+ int user_number = default_val;
+
+ unsigned n;
+ if (user_device_string != NULL &&
+ sscanf(user_device_string, " %u", &n) == 1) {
+ user_number = (int)n;
+ } else {
+ user_number = default_val;
+ }
+ return user_number;
+}
+
+static void ggml_check_sycl() try {
+ static bool initialized = false;
+
+ if (!initialized) {
+ fprintf(stderr, "[SYCL] call ggml_check_sycl\n");
+ g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
+
+ fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
+
+#if defined(GGML_SYCL_F16)
+ fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__);
+#else
+ fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
+#endif
+
+/* DO NOT REMOVE; kept for a future XMX optimization.
+#if defined(SYCL_USE_XMX)
+ fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
+#else
+ fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
+#endif
+*/
+
+ if (CHECK_TRY_ERROR(g_all_sycl_device_count =
+ dpct::dev_mgr::instance().device_count()) != 0) {
+ initialized = true;
+ g_sycl_loaded = false;
+ return;
+ }
+ GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
+ ggml_backend_sycl_print_sycl_devices();
+ initialized = true;
+ g_sycl_loaded = true;
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static ggml_sycl_device_info ggml_sycl_init() {
+ ggml_sycl_device_info info = {};
+
+ info.device_count = dpct::dev_mgr::instance().device_count();
+ if (info.device_count == 0) {
+        fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": no devices found\n", __func__);
+ return info;
+ }
+
+ GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
+
+ int64_t total_vram = 0;
+#if defined(GGML_SYCL_FORCE_MMQ)
+ fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
+#else
+ fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
+#endif
+#if defined(SYCL_USE_XMX)
+ fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
+#else
+ fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
+#endif
+ fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count);
+
+ for (int i = 0; i < info.device_count; ++i) {
+ info.devices[i].vmm = 0;
+ dpct::device_info prop;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+ prop, dpct::dev_mgr::instance().get_device(i))));
+
+ info.default_tensor_split[i] = total_vram;
+ total_vram += prop.get_global_mem_size();
+
+ info.devices[i].cc =
+ 100 * prop.get_major_version() + 10 * prop.get_minor_version();
+
+ info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+ }
+
+ for (int id = 0; id < info.device_count; ++id) {
+ info.default_tensor_split[id] /= total_vram;
+ }
+ return info;
+}
+
+const ggml_sycl_device_info & ggml_sycl_info() {
+ static ggml_sycl_device_info info = ggml_sycl_init();
+ return info;
+}
+
+/*
+device_index: device index, consecutive numbers starting from 0.
+    It is used for device selection/setting in the SYCL backend's internal data structures.
+*/
+inline void check_allow_gpu_index(const int device_index) {
+ if (device_index >= ggml_sycl_info().device_count) {
+ char error_buf[256];
+ snprintf(
+ error_buf,
+ sizeof(error_buf),
+ "%s error: device_index:%d is out of range: [0-%d]",
+ __func__,
+ device_index,
+ ggml_sycl_info().device_count - 1);
+ fprintf(stderr, "%s\n", error_buf);
+ assert(false);
+ }
+}
+
+// buffer pool for sycl (legacy)
+struct ggml_sycl_pool_leg : public ggml_sycl_pool {
+ static const int MAX_SYCL_BUFFERS = 256;
+
+ int device;
+ queue_ptr qptr;
+ struct ggml_sycl_buffer {
+ void * ptr = nullptr;
+ size_t size = 0;
+ };
+
+ ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
+ size_t pool_size = 0;
+
+    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
+        device(device_),
+        qptr(qptr_) {
+    }
+
+ ~ggml_sycl_pool_leg() {
+ for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+ ggml_sycl_buffer & b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+ pool_size -= b.size;
+ }
+ }
+ GGML_ASSERT(pool_size == 0);
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
+#ifdef DEBUG_SYCL_MALLOC
+ int nnz = 0;
+ size_t max_size = 0;
+#endif
+ size_t best_diff = 1ull << 36;
+ int ibest = -1;
+ for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+ ggml_sycl_buffer& b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+#ifdef DEBUG_SYCL_MALLOC
+ ++nnz;
+ if (b.size > max_size) max_size = b.size;
+#endif
+ if (b.size >= size) {
+ size_t diff = b.size - size;
+ if (diff < best_diff) {
+ best_diff = diff;
+ ibest = i;
+ if (!best_diff) {
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ }
+ }
+ }
+ }
+ if (ibest >= 0) {
+ ggml_sycl_buffer& b = buffer_pool[ibest];
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ void * ptr;
+ size_t look_ahead_size = (size_t) (1.05 * size);
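+        // over-allocate by 5% so that a slightly larger request later on can still
+        // be served from this pooled buffer instead of a fresh device allocation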
+
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
+ look_ahead_size, *qptr)));
+ *actual_size = look_ahead_size;
+ pool_size += look_ahead_size;
+
+    #ifdef DEBUG_SYCL_MALLOC
+        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+                (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024));
+    #endif
+ // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
+ return ptr;
+ }
+
+ void free(void * ptr, size_t size) override {
+ for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+ ggml_sycl_buffer& b = buffer_pool[i];
+ if (b.ptr == nullptr) {
+ b.ptr = ptr;
+ b.size = size;
+ return;
+ }
+ }
+        fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_SYCL_BUFFERS\n");
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
+ pool_size -= size;
+ }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
+ // TBD: NO VMM support
+ // if (ggml_sycl_info().devices[device].vmm) {
+ // return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
+ // }
+ return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
+}
+
+// TBD pool with virtual memory management
+// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
+
+static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
+ const struct ggml_tensor *src,
+ int64_t i3, int64_t i2,
+ int64_t i1_low, int64_t i1_high,
+ queue_ptr stream) try {
+
+ dpct::memcpy_direction kind;
+ char * src_ptr;
+ if (src->backend == GGML_BACKEND_TYPE_CPU) {
+ kind = dpct::host_to_device;
+ src_ptr = (char *) src->data;
+ // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
+ } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
+ GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
+ kind = dpct::device_to_device;
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+ int id;
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ id = get_current_device_id()));
+ // GGML_SYCL_DEBUG("current device index %d\n", id);
+ src_ptr = (char *) extra->data_device[id];
+ } else {
+ // GGML_SYCL_DEBUG("GGML_ASSERT(false)\n");
+ GGML_ASSERT(false);
+ }
+ char * dst_ptr = (char *) dst;
+
+ GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
+ GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
+ const enum ggml_type type = src->type;
+ const int64_t ts = ggml_type_size(type);
+ const int64_t bs = ggml_blck_size(type);
+ int64_t i1_diff = i1_high - i1_low;
+
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
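+    // three copy paths: fully contiguous rows -> one linear async memcpy;
+    // contiguous elements but padded rows (nb0 == ts) -> one 2D strided memcpy;
+    // otherwise -> one strided memcpy per row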
+ if (nb0 == ts && nb1 == ts*ne0/bs) {
+ // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
+ // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
+ return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
+ kind, *stream));
+
+ } else if (nb0 == ts) {
+ return CHECK_TRY_ERROR(
+ dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
+ ts * ne0 / bs, i1_diff, kind, *stream));
+ } else {
+ for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+ const void * rx = (const void *) ((const char *) x + i1*nb1);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ // pretend the row is a matrix with cols=1
+ dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+ rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
+ /*
+ DPCT1001:85: The statement could not be removed.
+ */
+ /*
+ DPCT1000:86: Error handling if-stmt was detected but could not be
+ rewritten.
+ */
+ if (r != 0) return r;
+ }
+ return 0;
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_d, const float *src1_d,
+ float *dst_d, const queue_ptr &stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+ const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+ src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_F32:
+ get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q4_0:
+ get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+ break;
+ default:
+ // TODO: k-quants
+ fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+ GGML_ASSERT(false);
+ break;
+ }
+}
+
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+ (sycl::half *)dst_dd, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+ main_stream);
+ } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+ op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+ main_stream);
+ } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+ op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+ main_stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_d, const float *src1_d,
+ float *dst_d,
+ const queue_ptr &main_stream) {
+
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
+
+ (void) src1;
+ (void) src1_d;
+}
+
+inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+ int offset = dst->op_params[3] / 4; // offset in bytes
+
+ acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+ (void) dst;
+}
+
+inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd, const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const float sf0 = (float)dst->ne[0]/src0->ne[0];
+ const float sf1 = (float)dst->ne[1]/src0->ne[1];
+ const float sf2 = (float)dst->ne[2]/src0->ne[2];
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
+ upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
+ main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+ pad_f32_sycl(src0_dd, dst_dd,
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
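+// Returns the multiple that per-device row boundaries are rounded to when a weight
+// matrix is split across devices; it depends on the quantization type and on the
+// highest compute capability among the devices that actually receive rows.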
+static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
+ int64_t min_compute_capability = INT_MAX;
+ int64_t max_compute_capability = INT_MIN;
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
+ if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
+ min_compute_capability = ggml_sycl_info().devices[i].cc;
+ }
+ if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
+ max_compute_capability = ggml_sycl_info().devices[i].cc;
+ }
+ }
+ }
+
+ switch(type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return max_compute_capability >= VER_GEN9 ? 128 : 64;
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return 64;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return 1;
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ return max_compute_capability >= VER_GEN9 ? 128 : 64;
+ case GGML_TYPE_IQ3_S:
+ return max_compute_capability >= VER_GEN9 ? 128 : 64;
+ case GGML_TYPE_Q6_K:
+ return 64;
+ default:
+ GGML_ASSERT(false);
+ }
+
+}
+
+inline void ggml_sycl_op_mul_mat_sycl(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const queue_ptr &stream) try {
+
+ GGML_ASSERT(src0_dd_i != nullptr);
+ GGML_ASSERT(src1_ddf_i != nullptr);
+ GGML_ASSERT(dst_dd_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that the GEMM writes into
+ int ldc = id == ctx.device ? ne0 : row_diff;
+
+#ifdef GGML_SYCL_F16
+ bool use_fp16 = true; // TODO(Yu) SYCL capability check
+#else
+ bool use_fp16 = false;
+#endif
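+    // fp16 GEMM path: taken when src0 is fp16 (or quantized), fp16 is enabled,
+    // src0 is contiguous, this device handles all of src0's rows and no higher
+    // precision is requested; both operands are converted to fp16 as needed and
+    // the result is converted back to fp32.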
+ if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+ use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
+ dst->op_params[0] == GGML_PREC_DEFAULT) {
+
+ // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
+ ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
+ if (src0->type != GGML_TYPE_F16) {
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
+ GGML_ASSERT(to_fp16_sycl != nullptr);
+ size_t ne = row_diff*ne00;
+ src0_as_f16.alloc(ne);
+ to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
+ }
+ const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
+ ? (const sycl::half *)src0_dd_i
+ : src0_as_f16.get();
+
+ ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ GGML_ASSERT(to_fp16_sycl != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_f16.alloc(ne);
+ to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
+ }
+ const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
+ ? (const sycl::half *)src1->data + src1_padded_row_size
+ : src1_as_f16.get();
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
+ const sycl::half alpha_f16 = 1.0f;
+ const sycl::half beta_f16 = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+ *stream, oneapi::mkl::transpose::trans,
+ oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+ src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+ dst_f16.get(), dpct::library_data_t::real_half, ldc,
+ dpct::library_data_t::real_half)));
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
+ else {
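+        // fp32 fallback: dequantize/convert both operands to fp32 if needed and
+        // run a plain oneMKL SGEMM.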
+ // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
+ ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
+ ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
+ if (src0->type != GGML_TYPE_F32) {
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
+ GGML_ASSERT(to_fp32_sycl != nullptr);
+ src0_ddq_as_f32.alloc(row_diff*ne00);
+ to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+ }
+ if (src1->type != GGML_TYPE_F32) {
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
+ GGML_ASSERT(to_fp32_sycl != nullptr);
+ src1_ddq_as_f32.alloc(src1_ncols*ne10);
+ to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+ }
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+ const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+ *stream, oneapi::mkl::transpose::trans,
+ oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+ dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
+ src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
+ dst_dd_i, ldc)));
+ }
+ (void) dst;
+ (void) src1_ddq_i;
+ (void) src1_padded_row_size;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd, const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ const int64_t IH = src0->ne[1];
+ const int64_t IW = src0->ne[0];
+
+ const int64_t N = dst->ne[3];
+ const int64_t OC = dst->ne[2];
+ const int64_t OH = dst->ne[1];
+ const int64_t OW = dst->ne[0];
+
+ const int parallel_elements = N * OC * OH * OW;
+ const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
+ sycl::range<3> block_nums(1, 1, num_blocks);
+    main_stream->parallel_for(
+        sycl::nd_range<3>(block_nums *
+                              sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_POOL2D_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
+ parallel_elements, src0_dd, dst_dd, op,
+ item_ct1);
+ });
+
+ (void) src1;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+ const int64_t IC = src1->ne[is_2D ? 2 : 1];
+ const int64_t IH = is_2D ? src1->ne[1] : 1;
+ const int64_t IW = src1->ne[0];
+
+ const int64_t KH = is_2D ? src0->ne[1] : 1;
+ const int64_t KW = src0->ne[0];
+
+ const int64_t OH = is_2D ? dst->ne[2] : 1;
+ const int64_t OW = dst->ne[1];
+
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+ if (dst->type == GGML_TYPE_F16) {
+ im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ } else {
+ im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+ }
+
+ (void) src0;
+ (void) src0_dd;
+}
+
+inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+ argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int nrows0 = ggml_nrows(src0);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+
+ diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(float));
+
+ scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
+ /*
+ DPCT1010:87: SYCL uses exceptions to report errors and does not use the
+ error codes. The call was replaced with 0. You need to rewrite this code.
+ */
+ SYCL_CHECK(0);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float min;
+ float max;
+ memcpy(&min, dst->op_params, sizeof(float));
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+ clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+ /*
+ DPCT1010:88: SYCL uses exceptions to report errors and does not use the
+ error codes. The call was replaced with 0. You need to rewrite this code.
+ */
+ SYCL_CHECK(0);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
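+// Runs a "flattened" (element-wise style) op on the device: the tensors' data
+// pointers are passed to `op` as float buffers on the context's main stream;
+// the data is assumed to already reside in device memory.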
+static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const ggml_sycl_op_flatten_t op) try {
+ const int64_t nrows0 = ggml_nrows(src0);
+
+ const bool use_src1 = src1 != nullptr;
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ // dd = data device
+ float * src0_ddf = (float *) src0->data;
+ float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+ float * dst_ddf = (float *) dst->data;
+
+ ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+ ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+ ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
+
+ ggml_sycl_set_device(ctx.device);
+ queue_ptr main_stream = ctx.stream();
+ // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+ // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+ // do the computation
+ op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+ // print_ggml_tensor("tensor", dst);
+}
+catch (sycl::exception const &exc) {
+
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
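+// Tracks whether peer-to-peer access should be enabled based on the batch size;
+// the actual SYCL peer-access enable/disable calls are still commented out below.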
+static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
+ static bool peer_access_enabled = false;
+
+ const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
+
+ if (peer_access_enabled == enable_peer_access) {
+ return;
+ }
+
+#ifdef NDEBUG
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ SYCL_CHECK(ggml_sycl_set_device(i));
+ }
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ SYCL_CHECK(ggml_sycl_set_device(i));
+
+ for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
+ if (i == id_other) {
+ continue;
+ }
+ if (i != main_device && id_other != main_device) {
+ continue;
+ }
+
+ // int can_access_peer;
+ // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
+ // if (can_access_peer) {
+ // if (enable_peer_access) {
+ // SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
+ // } else {
+ // SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
+ // }
+ // }
+ }
+ }
+#endif // NDEBUG
+
+ peer_access_enabled = enable_peer_access;
+}
+
+struct ggml_backend_sycl_split_buffer_type_context {
+ std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
+};
+
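+// Generic (optionally multi-device) matrix multiplication driver: splits src0's
+// rows across devices according to tensor_split, optionally quantizes src1 to
+// q8_1, processes src1 in column chunks, runs `op` per device and chunk, and
+// copies the partial results back to the destination on the main device.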
+static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ ggml_sycl_op_mul_mat_t op,
+ const bool convert_src1_to_q8_1) try {
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+ const int64_t nrows1 = ggml_nrows(src1);
+
+ GGML_ASSERT(ne03 == ne13);
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+
+ const int nb2 = dst->nb[2];
+ const int nb3 = dst->nb[3];
+
+ GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+ GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+
+ const int64_t i02_divisor = ne12 / ne02;
+
+ const size_t src0_ts = ggml_type_size(src0->type);
+ const size_t src0_bs = ggml_blck_size(src0->type);
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
+
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ const bool src0_is_contiguous = ggml_is_contiguous(src0);
+ const bool src1_is_contiguous = ggml_is_contiguous(src1);
+
+ int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+ const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
+ GGML_ASSERT(!(split && ne02 > 1));
+ GGML_ASSERT(!(split && ne03 > 1));
+ GGML_ASSERT(!(split && ne02 < ne12));
+
+ std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
+ if (split) {
+ // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
+ // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+ tensor_split = buft_ctx->tensor_split;
+ }
+
+ struct dev_data {
+ ggml_sycl_pool_alloc<char> src0_dd_alloc;
+ ggml_sycl_pool_alloc<float> src1_ddf_alloc;
+ ggml_sycl_pool_alloc<char> src1_ddq_alloc;
+ ggml_sycl_pool_alloc<float> dst_dd_alloc;
+
+ char *src0_dd = nullptr;
+ float *src1_ddf = nullptr; // float
+ char *src1_ddq = nullptr; // q8_1
+ float *dst_dd = nullptr;
+
+ int64_t row_low;
+ int64_t row_high;
+ };
+
+ dev_data dev[GGML_SYCL_MAX_DEVICES];
+
+ int used_devices = 0;
+ queue_ptr main_stream = ctx.stream();
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ // by default, use all rows
+ dev[i].row_low = 0;
+ dev[i].row_high = ne01;
+
+ // for multi GPU, get the row boundaries from tensor split
+ // and round to mul_mat_q tile sizes
+ if (split) {
+ const int64_t rounding = get_row_rounding(src0->type, tensor_split);
+
+ if (i != 0) {
+ dev[i].row_low = ne01*tensor_split[i];
+ if (dev[i].row_low < ne01) {
+ dev[i].row_low -= dev[i].row_low % rounding;
+ }
+ }
+
+ if (i != ggml_sycl_info().device_count - 1) {
+ dev[i].row_high = ne01*tensor_split[i + 1];
+ if (dev[i].row_high < ne01) {
+ dev[i].row_high -= dev[i].row_high % rounding;
+ }
+ }
+ }
+ }
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
+ continue;
+ }
+
+ used_devices++;
+
+ const bool src1_on_device = i == ctx.device;
+ const bool dst_on_device = i == ctx.device;
+
+ ggml_sycl_set_device(i);
+ queue_ptr stream = ctx.stream(i, 0);
+
+ if (src0_is_contiguous) {
+ dev[i].src0_dd = (char *) src0->data;
+ } else {
+ dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
+ }
+
+ if (src1_on_device && src1_is_contiguous) {
+ dev[i].src1_ddf = (float *) src1->data;
+ } else {
+ dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
+ }
+
+ if (convert_src1_to_q8_1) {
+ dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+
+ if (src1_on_device && src1_is_contiguous) {
+ quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+ /*
+ DPCT1010:90: SYCL uses exceptions to report errors and does not
+ use the error codes. The call was replaced with 0. You need to
+ rewrite this code.
+ */
+ SYCL_CHECK(0);
+ }
+ }
+
+ if (dst_on_device) {
+ dev[i].dst_dd = (float *) dst->data;
+ } else {
+ const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
+ dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
+ }
+ }
+
+ // if multiple devices are used they need to wait for the main device
+ // here an event is recorded that signals that the main device has finished calculating the input data
+ if (split && used_devices > 1) {
+ ggml_sycl_set_device(ctx.device);
+ /*
+ DPCT1024:91: The original code returned the error code that was further
+ consumed by the program logic. This original code was replaced with 0.
+ You may need to rewrite the program logic consuming the error code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ *src0_extra->events[ctx.device][0] =
+ ctx.stream()->ext_oneapi_submit_barrier()));
+ }
+
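+    // src1 is processed in column chunks; when src0 is split across multiple
+    // devices the chunks are MUL_MAT_SRC1_COL_STRIDE columns wide and cycle
+    // through up to GGML_SYCL_MAX_STREAMS streams per device, otherwise a single
+    // chunk covers all of ne11.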
+ const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+ for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+ const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
+ const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
+ continue;
+ }
+
+ const bool src1_on_device = i == ctx.device;
+ const bool dst_on_device = i == ctx.device;
+ const int64_t row_diff = dev[i].row_high - dev[i].row_low;
+
+ ggml_sycl_set_device(i);
+ queue_ptr stream = ctx.stream(i, is);
+
+ // wait for main GPU data if necessary
+ if (split && (i != ctx.device || is != 0)) {
+ /*
+ DPCT1009:163: SYCL uses exceptions to report errors and does not
+ use the error codes. The original code was commented out and a
+ warning string was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
+ {*src0_extra->events[ctx.device][0]})));
+ }
+
+ for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+ const int64_t i03 = i0 / ne12;
+ const int64_t i02 = i0 % ne12;
+
+ const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
+
+ // for split tensors the data begins at i0 == i0_offset_low
+ char * src0_dd_i = dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+ float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+ char * src1_ddq_i = dev[i].src1_ddq + src1_ddq_i_offset;
+ float * dst_dd_i = dev[i].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+ // the main device memory buffer can be on VRAM scratch, with space for all partial results
+ // in that case an offset on dst_ddf_i is needed
+ if (i == ctx.device) {
+ dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
+ }
+
+ // copy src0, src1 to device if necessary
+ if (src1_is_contiguous) {
+ if (i != ctx.device) {
+ if (convert_src1_to_q8_1) {
+ char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
+ src1_ddq_i, src1_ddq_i_source,
+ src1_ncols * src1_padded_col_size * q8_1_ts /
+ q8_1_bs).wait()));
+ } else {
+
+ float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
+ src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+
+ SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
+ src1_ddf_i, src1_ddf_i_source,
+ src1_ncols * ne10 * sizeof(float))));
+ }
+ }
+ } else if (src1_on_device && !src1_is_contiguous) {
+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
+ src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ if (convert_src1_to_q8_1 && !src1_is_contiguous) {
+ quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+ /*
+ DPCT1010:92: SYCL uses exceptions to report errors and does
+ not use the error codes. The call was replaced with 0. You
+ need to rewrite this code.
+ */
+ SYCL_CHECK(0);
+ }
+
+ if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
+ SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
+ }
+ if (src1->type == GGML_TYPE_F16) {
+ src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
+ }
+ // do the computation
+ SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+ dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
+ /*
+ DPCT1010:93: SYCL uses exceptions to report errors and does not
+ use the error codes. The call was replaced with 0. You need to
+ rewrite this code.
+ */
+ SYCL_CHECK(0);
+
+ // copy dst to host or other device if necessary
+ if (!dst_on_device) {
+ void * dst_off_device = dst->data;
+ if (split) {
+ // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+ // dst is NOT transposed.
+ // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+ // Instead they need to be copied to the correct slice in ne0 = dst row index.
+ // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
+
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+ dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
+ row_diff * sizeof(float), row_diff * sizeof(float),
+ src1_ncols, dpct::device_to_device, *stream)));
+ } else {
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0;
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memcpy(dhf_dst_i, dst_dd_i,
+ src1_ncols * ne0 * sizeof(float)).wait()));
+ }
+ }
+
+ // add event for the main device to wait on until other device is done
+ if (split && (i != ctx.device || is != 0)) {
+ /*
+ DPCT1024:94: The original code returned the error code that
+ was further consumed by the program logic. This original
+ code was replaced with 0. You may need to rewrite the
+ program logic consuming the error code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ *src0_extra->events[i][is] =
+ stream->ext_oneapi_submit_barrier()));
+ }
+ }
+ }
+ }
+
+ // main device waits for all other devices to be finished
+ if (split && ggml_sycl_info().device_count > 1) {
+ int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+ is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
+
+ ggml_sycl_set_device(ctx.device);
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ if (dev[i].row_low == dev[i].row_high) {
+ continue;
+ }
+ for (int64_t is = 0; is < is_max; ++is) {
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ ctx.stream()->ext_oneapi_submit_barrier(
+ {*src0_extra->events[i][is]})));
+ }
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+
+static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_SYCL_DEBUG("call %s\n", __func__);
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst) try {
+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t ne12 = src1->ne[2];
+
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ queue_ptr main_stream = ctx.stream();
+
+ void * src0_ddq = src0->data;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst) try {
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(!ggml_is_permuted(src0));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+
+ const int64_t ne12 = src1->ne[2];
+
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ queue_ptr main_stream = ctx.stream();
+
+ void * src0_ddq = src0->data;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ const int64_t row_stride_x = nb01 / sizeof(sycl::half);
+ const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+
+ ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
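+// Builds the per-matrix pointer tables used by the batched GEMM: each (i12, i13)
+// batch entry points at the corresponding src0/src1 slices (with broadcasting via
+// r2/r3) and at its output slice in dst.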
+static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
+ const sycl::half *src1_as_f16, char *dst,
+ const void **ptrs_src, void **ptrs_dst,
+ int64_t ne12, int64_t ne13, int64_t ne23,
+ size_t nb02, size_t nb03, size_t nb12,
+ size_t nb13, size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3,
+ const sycl::nd_item<3> &item_ct1) {
+ int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+ item_ct1.get_local_id(2);
+ int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (i13 >= ne13 || i12 >= ne12) {
+ return;
+ }
+
+ int64_t i03 = i13 / r3;
+ int64_t i02 = i12 / r2;
+
+ ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+ ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
+}
+
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst) try {
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t ne_dst = ggml_nelements(dst);
+
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    queue_ptr main_stream = ctx.stream();
+
+ void * src0_ddq = src0->data;
+ sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
+ float * src1_ddf = (float *) src1->data;
+ float * dst_ddf = (float *) dst->data;
+
+ // convert src1 to fp16
+ ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ const int64_t ne_src1 = ggml_nelements(src1);
+ src1_f16_alloc.alloc(ne_src1);
+ GGML_ASSERT(to_fp16_sycl != nullptr);
+ to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+ }
+ sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
+ : src1_f16_alloc.get();
+
+ char * dst_t;
+
+ dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
+ dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+
+ // dst strides
+ size_t nbd2 = dst->nb[2];
+ size_t nbd3 = dst->nb[3];
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ const void * alpha = &alpha_f32;
+ const void * beta = &beta_f32;
+
+ dst_t = (char *) dst_ddf;
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+ *main_stream, oneapi::mkl::transpose::trans,
+ oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const char *)src0_as_f16, dpct::library_data_t::real_half,
+ nb01 / nb00, nb02 / nb00,
+ (const char *)src1_f16, dpct::library_data_t::real_half,
+ nb11 / nb10, nb12 / nb10, beta,
+ (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
+ ne12 * ne13, cu_compute_type)));
+ } else {
+ const int ne23 = ne12*ne13;
+
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+ ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+ sycl::range<3> block_dims(1, ne12, ne13);
+ /*
+ DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(main_stream->get_device(),
+ {sycl::aspect::fp16});
+
+ main_stream->submit([&](sycl::handler &cgh) {
+ const void **ptrs_src_get = ptrs_src.get();
+ void **ptrs_dst_get = ptrs_dst.get();
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_compute_batched_ptrs(
+ src0_as_f16, src1_f16,
+ dst_t, ptrs_src_get,
+ ptrs_dst_get, ne12, ne13, ne23,
+ nb02, nb03, nb12_scaled, nb13_scaled,
+ nbd2, nbd3, r2, r3, item_ct1);
+ });
+ });
+ }
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+ *main_stream, oneapi::mkl::transpose::trans,
+ oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const void **)(ptrs_src.get() + 0 * ne23),
+ dpct::library_data_t::real_half, nb01 / nb00,
+ (const void **)(ptrs_src.get() + 1 * ne23),
+ dpct::library_data_t::real_half, nb11 / nb10, beta,
+ (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
+ cu_compute_type)));
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+ // TODO: accuracy issues in MMQ
+ return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
+}
+
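+// Top-level matmul dispatch: picks between the specialized f16 vector kernels
+// (permuted p021 / non-contiguous nc), the batched GEMM, dequantize-mul-mat-vec,
+// mul-mat-vec-q, mul-mat-q and the generic oneMKL GEMM path based on tensor
+// types, shapes, buffer splitting and device capabilities.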
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
+ int64_t min_compute_capability = INT_MAX;
+
+ if (split) {
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+ auto & tensor_split = buft_ctx->tensor_split;
+ for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
+ // skip devices that are not going to do any work:
+ if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
+ continue;
+ }
+
+ if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
+ min_compute_capability = ggml_sycl_info().devices[id].cc;
+ }
+ }
+ } else {
+ min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+ }
+
+ // check data types and tensor shapes for custom matrix multiplication kernels:
+ bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+ bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ // mmvq and mmq need the __dp4a instruction which is available for gen12+
+ // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+ use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+ use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
+
+ // mmvq path is faster in the CUDA backend.
+ if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+ use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+
+ if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+ // KQ single-batch
+ ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+ // KQV single-batch
+ ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ // KQ + KQV multi-batch
+ ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
+ } else if (use_dequantize_mul_mat_vec) {
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+ } else if (use_mul_mat_vec_q) {
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+ } else if (use_mul_mat_q) {
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+ } else {
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+ }
+}
+
+
+struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+};
+
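+// Gathers the src1 rows routed to expert i02 into a contiguous buffer and records
+// which (id, token) pair each gathered row came from so the results can be
+// scattered back afterwards.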
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+ const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+ int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+ const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+ int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+ const sycl::nd_item<3> &item_ct1, int &src1_row) {
+ int32_t iid1 = item_ct1.get_group(2);
+ int32_t id = item_ct1.get_group(1);
+
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+ if (row_id_i != i02) {
+ return;
+ }
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ if (item_ct1.get_local_id(2) == 0) {
+ src1_row =
+ dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+ cur_src1_row, 1);
+ row_mapping[src1_row] = {id, iid1};
+ }
+ /*
+ DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+ performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+#pragma unroll
+ for (int i = item_ct1.get_local_id(2); i < ne10;
+ i += item_ct1.get_local_range(2)) {
+ src1_row_contiguous[i] = src1_row_original[i];
+ }
+}
+
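+// Scatters rows from the contiguous dst buffer back to their original (i1, i2)
+// positions using the row mapping recorded above.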
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+ char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+ const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+ size_t nb2, const sycl::nd_item<3> &item_ct1) {
+ int32_t i = item_ct1.get_group(2);
+
+ const int32_t i1 = row_mapping[i].i1;
+ const int32_t i2 = row_mapping[i].i2;
+
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+#pragma unroll
+ for (int j = item_ct1.get_local_id(2); j < ne0;
+ j += item_ct1.get_local_range(2)) {
+ dst_row_original[j] = dst_row_contiguous[j];
+ }
+}
+
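+// mul_mat_id (mixture-of-experts matmul): for a single src1 batch (ne12 == 1) each
+// selected expert row is multiplied directly; otherwise the rows belonging to each
+// expert are gathered into a contiguous buffer, multiplied in one call and the
+// results scattered back.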
+static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst) try {
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
+ const ggml_tensor *ids = dst->src[2];
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const queue_ptr stream = ctx.stream();
+
+ const int64_t n_as = ne02;
+ const int64_t n_ids = ids->ne[0];
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ const char * ids_dev = (const char *) ids->data;
+
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
+
+ ggml_tensor src0_row = *src0;
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ char *src0_original = (char *)src0->data;
+ char *src1_original = (char *)src1->data;
+ char *dst_original = (char *)dst->data;
+
+ src0_row.ne[2] = 1;
+ src0_row.ne[3] = 1;
+ src0_row.nb[3] = nb02;
+
+ src1_row.ne[1] = 1;
+ src1_row.ne[2] = 1;
+ src1_row.ne[3] = 1;
+ src1_row.nb[2] = nb11;
+ src1_row.nb[3] = nb11;
+
+ dst_row.ne[1] = 1;
+ dst_row.ne[2] = 1;
+ dst_row.ne[3] = 1;
+ dst_row.nb[2] = nb1;
+ dst_row.nb[3] = nb1;
+ if (ne12 == 1) {
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = id;
+ const int64_t i2 = i12;
+
+ src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+ ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+ }
+ }
+ } else {
+ ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+ ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+
+ src1_row.data = src1_contiguous.get();
+ dst_row.data = dst_contiguous.get();
+
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
+ int64_t num_src1_rows = 0;
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+ if (row_id_i != i02) {
+ continue;
+ }
+
+ num_src1_rows++;
+ }
+ }
+
+ if (num_src1_rows == 0) {
+ continue;
+ }
+
+
+ ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+ ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+ {
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+ sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 0> src1_row_acc(cgh);
+
+ char *__restrict src1_contiguous_get =
+ src1_contiguous.get();
+ int *__restrict dev_cur_src1_row_get =
+ dev_cur_src1_row.get();
+ mmid_row_mapping *__restrict dev_row_mapping_get =
+ dev_row_mapping.get();
+ size_t ids_nb_ct6 = ids->nb[1];
+ size_t ids_nb_ct7 = ids->nb[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_copy_src1_to_contiguous(
+ src1_original, src1_contiguous_get,
+ dev_cur_src1_row_get,
+ dev_row_mapping_get, ids_dev, i02,
+ ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+ item_ct1, src1_row_acc);
+ });
+ });
+ }
+
+ src0_row.data = src0_original + i02*nb02;
+
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
+ src1_row.ne[1] = num_src1_rows;
+
+ src1_row.nb[1] = nb11;
+ src1_row.nb[2] = num_src1_rows*nb11;
+ src1_row.nb[3] = num_src1_rows*nb11;
+
+ dst_row.ne[1] = num_src1_rows;
+ dst_row.nb[1] = nb1;
+ dst_row.nb[2] = num_src1_rows*nb1;
+ dst_row.nb[3] = num_src1_rows*nb1;
+
+ ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+
+ {
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+ sycl::range<3> grid_dims(1, 1, num_src1_rows);
+ stream->submit([&](sycl::handler &cgh) {
+ const char *__restrict dst_contiguous_get =
+ dst_contiguous.get();
+ const mmid_row_mapping *__restrict dev_row_mapping_get =
+ dev_row_mapping.get();
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_copy_dst_from_contiguous(dst_original,
+ dst_contiguous_get,
+ dev_row_mapping_get,
+ ne0, nb1, nb2, item_ct1);
+ });
+ });
+ }
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
+}
+
+static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
+}
+
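+// Device-side tensor copy/conversion; dispatches on the (src0, src1) type pair and
+// supports f32/f16 conversions, quantization to q8_0/q4_0/q4_1 and i16/i32 copies.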
+static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+ ggml_tensor *dst) try {
+ const int64_t ne = ggml_nelements(src0);
+ GGML_ASSERT(ne == ggml_nelements(src1));
+
+ GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+ GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+ GGML_TENSOR_BINARY_OP_LOCALS01;
+
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ queue_ptr main_stream = ctx.stream();
+
+ char * src0_ddc = (char *) src0->data;
+ char * src1_ddc = (char *) src1->data;
+
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
+ ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+
+ (void) dst;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ // TODO: why do we pass dst as src1 here?
+ ggml_sycl_cpy(ctx, src0, dst, nullptr);
+ (void) src1;
+}
+
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
+}
+
+static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
+}
+
+static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
+}
+
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
+}
+
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
+}
+
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
+}
+
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+}
+
+static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ (void) src0;
+ (void) src1;
+ (void) dst;
+}
+
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
+}
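+
+// Illustrative note (editorial sketch, not part of the original patch): a split tensor's
+// per-device allocation is simply nrows_split full rows. For example, assuming a Q4_0
+// tensor with ne[0] = 4096 (blocks of 32 elements, 18 bytes each):
+//
+//     // ggml_row_size(GGML_TYPE_Q4_0, 4096) = 4096/32 * 18 = 2304 bytes per row
+//     size_t bytes = ggml_nbytes_split(tensor, /*nrows_split=*/128);  // 128 * 2304 = 294912
+//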
+
+void ggml_sycl_set_main_device(const int main_device) try {
+ if (dpct::get_current_device_id() == main_device) return;
+ check_allow_gpu_index(main_device);
+ dpct::select_device(main_device);
+
+ if (g_ggml_sycl_debug) {
+ dpct::device_info prop;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+ prop, dpct::dev_mgr::instance().get_device(main_device))));
+ fprintf(stderr, "Using device %d (%s) as main device\n",
+ main_device, prop.get_name());
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
+ if (!g_sycl_loaded) return false;
+
+ ggml_sycl_func_t func;
+
+ switch (tensor->op) {
+ case GGML_OP_REPEAT:
+ func = ggml_sycl_repeat;
+ break;
+ case GGML_OP_GET_ROWS:
+ func = ggml_sycl_get_rows;
+ break;
+ case GGML_OP_DUP:
+ func = ggml_sycl_dup;
+ break;
+ case GGML_OP_ADD:
+ func = ggml_sycl_add;
+ break;
+ case GGML_OP_ACC:
+ func = ggml_sycl_acc;
+ break;
+ case GGML_OP_MUL:
+ func = ggml_sycl_mul;
+ break;
+ case GGML_OP_DIV:
+ func = ggml_sycl_div;
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_GELU:
+ func = ggml_sycl_gelu;
+ break;
+ case GGML_UNARY_OP_SILU:
+ func = ggml_sycl_silu;
+ break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ func = ggml_sycl_gelu_quick;
+ break;
+ case GGML_UNARY_OP_TANH:
+ func = ggml_sycl_tanh;
+ break;
+ case GGML_UNARY_OP_RELU:
+ func = ggml_sycl_relu;
+ break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ func = ggml_sycl_hardsigmoid;
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ func = ggml_sycl_hardswish;
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_NORM:
+ func = ggml_sycl_norm;
+ break;
+ case GGML_OP_GROUP_NORM:
+ func = ggml_sycl_group_norm;
+ break;
+ case GGML_OP_CONCAT:
+ func = ggml_sycl_op_concat;
+ break;
+ case GGML_OP_UPSCALE:
+ func = ggml_sycl_upscale;
+ break;
+ case GGML_OP_PAD:
+ func = ggml_sycl_pad;
+ break;
+ case GGML_OP_LEAKY_RELU:
+ func = ggml_sycl_leaky_relu;
+ break;
+ case GGML_OP_RMS_NORM:
+ func = ggml_sycl_rms_norm;
+ break;
+ case GGML_OP_MUL_MAT:
+ if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+ return false;
+ }
+ func = ggml_sycl_mul_mat;
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+ return false;
+ }
+ func = ggml_sycl_mul_mat_id;
+ break;
+ case GGML_OP_SCALE:
+ func = ggml_sycl_scale;
+ break;
+ case GGML_OP_SQR:
+ func = ggml_sycl_sqr;
+ break;
+ case GGML_OP_CLAMP:
+ func = ggml_sycl_clamp;
+ break;
+ case GGML_OP_CPY:
+ func = ggml_sycl_cpy;
+ break;
+ case GGML_OP_CONT:
+ func = ggml_sycl_dup;
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ func = ggml_sycl_nop;
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ func = ggml_sycl_diag_mask_inf;
+ break;
+ case GGML_OP_SOFT_MAX:
+ func = ggml_sycl_soft_max;
+ break;
+ case GGML_OP_ROPE:
+ func = ggml_sycl_rope;
+ break;
+ case GGML_OP_IM2COL:
+ func = ggml_sycl_im2col;
+ break;
+ case GGML_OP_POOL_2D:
+ func = ggml_sycl_pool2d;
+ break;
+ case GGML_OP_SUM_ROWS:
+ func = ggml_sycl_sum_rows;
+ break;
+ case GGML_OP_ARGSORT:
+ func = ggml_sycl_argsort;
+ break;
+ default:
+ return false;
+ }
+
+ if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
+ ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
+ }
+
+ func(ctx, tensor->src[0], tensor->src[1], tensor);
+ return true;
+}
+
+GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len) try {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_gpu_list\n");
+ for (int i = 0; i < max_len; i++) id_list[i] = -1;
+
+ for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+ if (i >= max_len) break;
+ id_list[i] = i;
+ }
+ return;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+int ggml_sycl_get_device_count() try {
+ int device_count;
+ if (CHECK_TRY_ERROR(device_count =
+ dpct::dev_mgr::instance().device_count()) != 0) {
+ return 0;
+ }
+ return device_count;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description,
+ size_t description_size) try {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_device_description\n");
+ dpct::device_info prop;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+ prop, dpct::dev_mgr::instance().get_device(device))));
+ snprintf(description, description_size, "%s", prop.get_name());
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free,
+ size_t *total) try {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
+ ggml_sycl_set_device(device);
+
+ /*
+ DPCT1009:218: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string was
+ inserted. You need to rewrite this code.
+ */
+ /*
+ DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
+ device information which may not be supported by all compilers or runtimes.
+ You may need to adjust the code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+#define UNUSED GGML_UNUSED
+
+// sycl buffer
+
+struct ggml_backend_sycl_buffer_context {
+ int device;
+ void * dev_ptr = nullptr;
+ queue_ptr stream;
+ std::string name;
+
+ ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+ device(device), dev_ptr(dev_ptr), stream(stream) {
+ check_allow_gpu_index(device);
+ name = (GGML_SYCL_NAME + std::to_string(device));
+ }
+
+
+ ~ggml_backend_sycl_buffer_context() {
+ if (dev_ptr != nullptr) {
+ ggml_sycl_set_device(device);
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
+ }
+ }
+};
+
+GGML_CALL static const char * ggml_backend_sycl_buffer_get_name(ggml_backend_buffer_t buffer) {
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
+ return ctx->name.c_str();
+}
+
+GGML_CALL static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_sycl_buffer_get_name;
+}
+
+static void
+ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
+ ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+ ggml_sycl_set_device(ctx->device);
+
+ delete ctx;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
+ ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+ return ctx->dev_ptr;
+}
+
+GGML_CALL static void
+ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor *tensor) try {
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
+
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ assert(tensor->view_src->buffer->buft == buffer->buft);
+ tensor->backend = tensor->view_src->backend;
+ tensor->extra = tensor->view_src->extra;
+ return;
+ }
+
+
+ if (ggml_is_quantized(tensor->type)) {
+ // initialize padding to 0 to avoid possible NaN values
+ size_t original_size = ggml_nbytes(tensor);
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+ if (padded_size > original_size && tensor->view_src == nullptr) {
+ SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
+ (char *)tensor->data + original_size, 0,
+ padded_size - original_size).wait()));
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor *tensor,
+ const void *data, size_t offset,
+ size_t size) try {
+
+ ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+
+ ggml_sycl_set_device(ctx->device);
+ auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+ char* host_buf = (char*)malloc(size);
+ memcpy(host_buf, data, size);
+ SYCL_CHECK(
+ CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
+ .wait()));
+ free(host_buf);
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
+ const ggml_tensor *tensor,
+ void *data, size_t offset,
+ size_t size) try {
+
+ ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+
+ ggml_sycl_set_device(ctx->device);
+ auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
+
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream.memcpy(data, (const char *)tensor->data + offset, size)
+ .wait()));
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static bool
+ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
+ const ggml_tensor *src,
+ ggml_tensor *dst) try {
+ if (ggml_backend_buffer_is_sycl(src->buffer)) {
+ ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
+ ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
+
+ ggml_sycl_set_device(src_ctx->device);
+ /*
+ DPCT1009:198: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
+ ggml_sycl_set_device(dst_ctx->device);
+ /*
+ DPCT1009:199: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
+ /*
+ DPCT1009:200: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+
+ queue_ptr stream_dst = dst_ctx->stream;
+ queue_ptr stream_src = src_ctx->stream;
+ size_t size = ggml_nbytes(src);
+
+ // TODO: this is a dirty workaround for a known issue with device-to-device copies across GPUs
+ dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
+
+// TODO: known issue: direct device-to-device copies across GPUs fail; re-enable this path once it is fixed. DO NOT remove.
+#if 0
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
+ (char *)dst->data, (const char *)src->data, size).wait()));
+
+ /*
+ DPCT1009:201: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
+#endif
+ return true;
+ }
+ return false;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+
+static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
+ uint8_t value) try {
+ ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+
+ ggml_sycl_set_device(ctx->device);
+ queue_ptr stream = ctx->stream;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
+
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream)
+ .memset(ctx->dev_ptr, value, buffer->size)
+ .wait()));
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
+ /* .get_name = */ ggml_backend_sycl_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_sycl_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_sycl_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// sycl buffer type
+struct ggml_backend_sycl_buffer_type_context {
+ int device;
+ std::string name;
+
+ // each buffer type has its own stream
+ queue_ptr stream = nullptr;
+};
+
+GGML_CALL static const char * ggml_backend_sycl_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+
+ return ctx->name.c_str();
+}
+GGML_CALL static ggml_backend_buffer_t
+ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+ size_t size) try {
+ ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+ ggml_sycl_set_device(buft_ctx->device);
+ const queue_ptr stream = buft_ctx->stream;
+ size = std::max(size, (size_t)1); // sycl::malloc_device may return nullptr for size 0
+
+ void * dev_ptr;
+ SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
+ size, *stream)));
+ ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
+ return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+ UNUSED(buft);
+}
+
+static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ return dpct::get_current_device().get_max_mem_alloc_size();
+
+ UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ size_t size = ggml_nbytes(tensor);
+ int64_t ne0 = tensor->ne[0];
+
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return size;
+
+ UNUSED(buft);
+}
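+
+// Worked example (editorial sketch, not part of the original patch): quantized rows are
+// padded up to a multiple of MATRIX_ROW_PADDING elements (512, per the "pad last row to a
+// multiple of 512 elements" comments elsewhere in this file). Assuming a Q4_0 tensor with
+// ne0 = 4000, the pad is 512 - 4000 % 512 = 96 elements, i.e. 3 extra 18-byte blocks:
+//
+//     size_t size = ggml_nbytes(tensor);              // unpadded size
+//     size += ggml_row_size(GGML_TYPE_Q4_0, 96);      // + 54 bytes of padding
+//
+// so quantized allocations are slightly larger than ggml_nbytes() reports.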
+
+static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_sycl_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size,
+ /* .is_host = */ nullptr,
+};
+
+ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
+
+ if (device >= ggml_sycl_info().device_count || device < 0) {
+ printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
+ device, ggml_sycl_info().device_count - 1);
+ GGML_ASSERT(device < ggml_sycl_info().device_count);
+ }
+ static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
+
+ static bool ggml_backend_sycl_buffer_type_initialized = false;
+
+ if (!ggml_backend_sycl_buffer_type_initialized) {
+ for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+ auto & device_i = dpct::dev_mgr::instance().get_device(i);
+ queue_ptr stream = &(device_i.default_queue());
+ ggml_backend_sycl_buffer_types[i] = {
+ /* .iface = */ ggml_backend_sycl_buffer_type_interface,
+ /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+ };
+ }
+ ggml_backend_sycl_buffer_type_initialized = true;
+ }
+ return &ggml_backend_sycl_buffer_types[device];
+}
+
+ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
+
+ int device = ctx->device;
+ if (device >= ggml_sycl_info().device_count || device < 0) {
+ printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d]; did you forget to call ggml_backend_sycl_set_single_device()?\n",
+ device, ggml_sycl_info().device_count - 1);
+ GGML_ASSERT(device < ggml_sycl_info().device_count);
+ }
+ static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
+
+ static bool ggml_backend_sycl_buffer_type_initialized = false;
+
+ if (!ggml_backend_sycl_buffer_type_initialized) {
+ for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+ ggml_backend_sycl_buffer_types[i] = {
+ /* .iface = */ ggml_backend_sycl_buffer_type_interface,
+ /* .context = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
+ };
+ }
+ ggml_backend_sycl_buffer_type_initialized = true;
+ }
+ return &ggml_backend_sycl_buffer_types[device];
+}
+
+// sycl split buffer type
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
+ const int64_t nrows = ggml_nrows(tensor);
+ const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
+
+ *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+ *row_low -= *row_low % rounding;
+ if (id == ggml_sycl_info().device_count - 1) {
+ *row_high = nrows;
+ } else {
+ *row_high = nrows*tensor_split[id + 1];
+ *row_high -= *row_high % rounding;
+ }
+}
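+
+// Worked example (editorial sketch, not part of the original patch): tensor_split holds
+// cumulative fractions of the row count. Assuming 2 devices, tensor_split = {0.0f, 0.75f},
+// nrows = 1000 and a rounding granularity of 32 rows (the actual value comes from
+// get_row_rounding()):
+//
+//     device 0: row_low = 0,                          row_high = 1000*0.75 = 750 -> snapped down to 736
+//     device 1: row_low = 750 -> snapped down to 736, row_high = nrows = 1000
+//
+// Every boundary except the final row_high is rounded down to a multiple of the rounding.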
+
+struct ggml_backend_sycl_split_buffer_context {
+ ~ggml_backend_sycl_split_buffer_context() try {
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+ if (extra->events[i][is] != nullptr) {
+ /*
+ DPCT1009:206: SYCL uses exceptions to report errors and
+ does not use the error codes. The original code was
+ commented out and a warning string was inserted. You
+ need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ dpct::destroy_event(extra->events[i][is])));
+ }
+ }
+ if (extra->data_device[i] != nullptr) {
+ /*
+ DPCT1009:207: SYCL uses exceptions to report errors and does
+ not use the error codes. The original code was commented out
+ and a warning string was inserted. You need to rewrite this
+ code.
+ */
+ ggml_sycl_set_device(i);
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
+ extra->data_device[i], *(streams[i]))));
+ }
+ }
+ delete extra;
+ }
+ }
+ catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+ }
+
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;
+ std::vector<queue_ptr> streams;
+};
+
+GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backend_buffer_t buffer) {
+ return GGML_SYCL_NAME "_Split";
+
+ UNUSED(buffer);
+}
+
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
+
+GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+ delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+ // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
+ return (void *)0x1000;
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void
+ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor *tensor) try {
+ GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+
+ ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+
+ ctx->tensor_extras.push_back(extra);
+ ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ // FIXME: do not crash if the device allocation fails
+ // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+ ggml_sycl_set_device(i);
+ const queue_ptr stream = ctx->streams[i];
+ char * buf;
+ /*
+ DPCT1009:208: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
+ size, *stream)));
+
+ // set padding to 0 to avoid possible NaN values
+ if (size > original_size) {
+ /*
+ DPCT1009:209: SYCL uses exceptions to report errors and does not use
+ the error codes. The original code was commented out and a warning
+ string was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ (*stream)
+ .memset(buf + original_size, 0, size - original_size)
+ .wait()));
+ }
+
+ extra->data_device[i] = buf;
+
+ for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+ /*
+ DPCT1009:210: SYCL uses exceptions to report errors and does not use
+ the error codes. The original code was commented out and a warning
+ string was inserted. You need to rewrite this code.
+ */
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
+ }
+ }
+ tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
+ tensor->extra = extra;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static void
+ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
+ ggml_tensor *tensor, const void *data,
+ size_t offset, size_t size) try {
+ // split tensors must always be set in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ const char * buf_host = (const char *)data + offset_split;
+ /*
+ DPCT1009:211: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ ggml_sycl_set_device(i);
+ const queue_ptr stream = ctx->streams[i];
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ (*stream)
+ .memcpy(extra->data_device[i], buf_host, original_size)
+ .wait()));
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static void
+ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
+ const ggml_tensor *tensor, void *data,
+ size_t offset, size_t size) try {
+ // split tensors must always be read in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ char * buf_host = (char *)data + offset_split;
+ /*
+ DPCT1009:212: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ ggml_sycl_set_device(i);
+ const queue_ptr stream = ctx->streams[i];
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ (*stream)
+ .memcpy(buf_host, extra->data_device[i], original_size)
+ .wait()));
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ UNUSED(buffer);
+ UNUSED(value);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
+ /* .get_name = */ ggml_backend_sycl_split_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_sycl_split_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_sycl_split_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+GGML_CALL static const char * ggml_backend_sycl_split_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_SYCL_NAME "_Split";
+
+ UNUSED(buft);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+ // instead, we allocate them for each tensor separately in init_tensor
+ // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+ // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+ ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
+
+ return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+ UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
+
+ size_t total_size = 0;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ total_size += ggml_nbytes_split(tensor, nrows_split);
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return total_size;
+}
+
+GGML_CALL static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
+
+ UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_sycl_split_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
+ /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host,
+};
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
+ ggml_check_sycl();
+ // FIXME: this is not thread safe
+ static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
+
+ std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
+
+ bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
+ if (all_zero) {
+ tensor_split_arr = ggml_sycl_info().default_tensor_split;
+ } else {
+ float split_sum = 0.0f;
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ tensor_split_arr[i] = split_sum;
+ split_sum += tensor_split[i];
+ }
+ for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+ tensor_split_arr[i] /= split_sum;
+ }
+ }
+
+ auto it = buft_map.find(tensor_split_arr);
+ if (it != buft_map.end()) {
+ return &it->second;
+ }
+
+ struct ggml_backend_buffer_type buft {
+ /* .iface = */ ggml_backend_sycl_split_buffer_type_interface,
+ /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
+ };
+
+ auto result = buft_map.emplace(tensor_split_arr, buft);
+ return &result.first->second;
+}
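+
+// Worked example (editorial sketch, not part of the original patch): the user-supplied
+// tensor_split lists per-device proportions, which are converted to cumulative, normalized
+// fractions before being stored in the buffer type context. Assuming 2 devices and
+// tensor_split = {3.0f, 1.0f}:
+//
+//     prefix sums -> {0.0f, 3.0f}, split_sum = 4.0f
+//     normalized  -> {0.0f, 0.75f}
+//
+// so device 0 owns rows [0, 75%) and device 1 the rest (see get_row_split() above).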
+
+// host buffer type
+
+GGML_CALL static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_SYCL_NAME "_Host";
+
+ UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_sycl_host_buffer_name(ggml_backend_buffer_t buffer) {
+ return GGML_SYCL_NAME "_Host";
+
+ UNUSED(buffer);
+}
+
+static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_sycl_host_free(buffer->context);
+}
+
+static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ void * ptr = ggml_sycl_host_malloc(size);
+
+ if (ptr == nullptr) {
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+ }
+
+ // FIXME: this is a hack to avoid having to implement a new buffer type
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.get_name = ggml_backend_sycl_host_buffer_name;
+ buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
+
+ return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
+ static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_sycl_host_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+ /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+ },
+ /* .context = */ nullptr,
+ };
+
+ return &ggml_backend_sycl_buffer_type_host;
+}
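+
+// Usage sketch (editorial note, not part of the original patch): pinned host buffers speed
+// up host<->device transfers and transparently fall back to a plain CPU buffer when pinned
+// allocation fails (see alloc_buffer above). n_bytes below is a placeholder size:
+//
+//     ggml_backend_buffer_t host_buf =
+//         ggml_backend_buft_alloc_buffer(ggml_backend_sycl_host_buffer_type(), n_bytes);
+//     // ... stage tensor data in host_buf ...
+//     ggml_backend_buffer_free(host_buf);
+//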
+
+// backend
+
+GGML_CALL static const char * ggml_backend_sycl_name(ggml_backend_t backend) {
+
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+
+ return sycl_ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+
+ delete sycl_ctx;
+ delete backend;
+}
+
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ return ggml_backend_sycl_buffer_type(sycl_ctx->device);
+}
+
+GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
+ ggml_tensor *tensor,
+ const void *data, size_t offset,
+ size_t size) try {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
+ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
+ SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
+ (char *)tensor->data + offset, data, size).wait()));
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
+ const ggml_tensor *tensor,
+ void *data, size_t offset,
+ size_t size) try {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
+ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
+ SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
+ data, (const char *)tensor->data + offset, size).wait()));
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
+ const ggml_tensor *src,
+ ggml_tensor *dst) try {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) {
+ /*
+ DPCT1009:215: SYCL uses exceptions to report errors and does not use the
+ error codes. The original code was commented out and a warning string
+ was inserted. You need to rewrite this code.
+ */
+ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
+ SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
+ dst->data, src->data, ggml_nbytes(dst)).wait()));
+ return true;
+ }
+
+ return false;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
+ SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
+
+ UNUSED(backend);
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ ggml_sycl_set_main_device(sycl_ctx->device);
+
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+#ifndef NDEBUG
+ assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j] != nullptr) {
+ assert(node->src[j]->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
+ }
+ }
+#endif
+ bool ok = ggml_sycl_compute_forward(*sycl_ctx, node);
+ if (!ok) {
+ fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+ }
+
+ return GGML_STATUS_SUCCESS;
+}
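+
+// Usage sketch (editorial note, not part of the original patch): graph_compute is normally
+// reached through the generic backend API rather than called directly; gf below is a
+// placeholder for a graph whose tensors live in SYCL buffers:
+//
+//     ggml_backend_graph_compute(backend, gf);  // dispatches each node via
+//                                               // ggml_sycl_compute_forward()
+//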
+
+GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_TANH:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ struct ggml_tensor * a;
+ struct ggml_tensor * b;
+ if (op->op == GGML_OP_MUL_MAT) {
+ a = op->src[0];
+ b = op->src[1];
+ } else {
+ a = op->src[2];
+ b = op->src[1];
+ }
+ if (a->ne[3] != b->ne[3]) {
+ return false;
+ }
+ ggml_type a_type = a->type;
+ if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
+ a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
+ a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S ||
+ a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M
+ ) {
+ if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+ return false;
+ }
+ }
+ ggml_type src0_type = op->src[0]->type;
+ if (src0_type == GGML_TYPE_BF16) {
+ return false;
+ }
+ return true;
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ } break;
+ case GGML_OP_CONCAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ int dim = op->op_params[0];
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
+ } break;
+ case GGML_OP_DUP:
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_REPEAT:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NORM:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CONT:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ return true;
+ case GGML_OP_ROPE:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_IM2COL:
+ case GGML_OP_POOL_2D:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_ACC:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_LEAKY_RELU:
+ return true;
+ default:
+ return false;
+ }
+
+ UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+ const int min_batch_size = 32;
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
+ GGML_UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_sycl_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_name != ggml_backend_sycl_buffer_type_name) {
+ return false;
+ }
+ ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+ ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
+ return buft_ctx->device == sycl_ctx->device;
+}
+
+static ggml_backend_i ggml_backend_sycl_interface = {
+ /* .get_name = */ ggml_backend_sycl_name,
+ /* .free = */ ggml_backend_sycl_free,
+ /* .get_default_buffer_type = */ ggml_backend_sycl_get_default_buffer_type,
+ /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async,
+ /* .cpy_tensor_async = */ NULL, //ggml_backend_sycl_cpy_tensor_async, // TODO: update for the new interface
+ /* .synchronize = */ ggml_backend_sycl_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_sycl_graph_compute,
+ /* .supports_op = */ ggml_backend_sycl_supports_op,
+ /* .supports_buft = */ ggml_backend_sycl_supports_buft,
+ /* .offload_op = */ ggml_backend_sycl_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_sycl_guid() {
+ static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
+ return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n");
+ ggml_check_sycl();
+
+ check_allow_gpu_index(device);
+
+ ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device);
+ if (ctx == nullptr) {
+ fprintf(stderr, "%s: error: failed to allocate context\n", __func__);
+ return nullptr;
+ };
+
+ ggml_backend_t sycl_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_sycl_guid(),
+ /* .interface = */ ggml_backend_sycl_interface,
+ /* .context = */ ctx
+ };
+
+ return sycl_backend;
+}
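+
+// Usage sketch (editorial note, not part of the original patch): a host application would
+// typically enumerate the available SYCL devices and create one backend per selected device:
+//
+//     int n_dev = ggml_backend_sycl_get_device_count();
+//     ggml_backend_t backend = n_dev > 0 ? ggml_backend_sycl_init(0) : nullptr;
+//     // ... run graphs ...
+//     ggml_backend_free(backend);  // ends up in ggml_backend_sycl_free()
+//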
+
+bool ggml_backend_is_sycl(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
+}
+
+GGML_CALL int ggml_backend_sycl_get_device_count() {
+ GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
+ return ggml_sycl_info().device_count;
+}
+
+GGML_CALL static ggml_backend_t ggml_backend_reg_sycl_init(const char * params, void * user_data) {
+ ggml_backend_t sycl_backend = ggml_backend_sycl_init((int) (intptr_t) user_data);
+ return sycl_backend;
+
+ UNUSED(params);
+}
+
+extern "C" int ggml_backend_sycl_reg_devices();
+
+int ggml_backend_sycl_reg_devices() {
+ assert(ggml_sycl_info().device_count>0);
+ for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+ char name[128];
+ snprintf(name, sizeof(name), "%s%d", GGML_SYCL_NAME, i);
+ ggml_backend_register(name, ggml_backend_reg_sycl_init, ggml_backend_sycl_buffer_type(i), (void *) (intptr_t) i);
+ }
+ return ggml_sycl_info().device_count;
+}
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
new file mode 100644
index 00000000..067181de
--- /dev/null
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_BACKEND_HPP
+#define GGML_SYCL_BACKEND_HPP
+
+#include "concat.hpp"
+#include "common.hpp"
+#include "convert.hpp"
+#include "dequantize.hpp"
+#include "dmmv.hpp"
+#include "mmq.hpp"
+#include "mmvq.hpp"
+#include "rope.hpp"
+#include "norm.hpp"
+#include "softmax.hpp"
+
+#endif // GGML_SYCL_BACKEND_HPP
diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp
new file mode 100644
index 00000000..e878f4f5
--- /dev/null
+++ b/ggml/src/ggml-sycl/common.cpp
@@ -0,0 +1,53 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "common.hpp"
+
+int get_current_device_id() {
+ return dpct::dev_mgr::instance().current_device_id();
+}
+
+void* ggml_sycl_host_malloc(size_t size) try {
+ if (getenv("GGML_SYCL_NO_PINNED") != nullptr) {
+ return nullptr;
+ }
+
+ void* ptr = nullptr;
+ // use dpct::get_in_order_queue() for the pinned host allocation
+ dpct::err0 err = CHECK_TRY_ERROR(
+ ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue()));
+
+ if (err != 0) {
+ // clear the error
+ fprintf(
+ stderr,
+ "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+ size / 1024.0 / 1024.0,
+ "syclGetErrorString is not supported");
+ return nullptr;
+ }
+
+ return ptr;
+} catch (sycl::exception const& exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+void ggml_sycl_host_free(void* ptr) try {
+ // free memory that was allocated on dpct::get_in_order_queue()
+ SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
+} catch (sycl::exception const& exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
new file mode 100644
index 00000000..397bd98d
--- /dev/null
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -0,0 +1,355 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_COMMON_HPP
+#define GGML_SYCL_COMMON_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "dpct/helper.hpp"
+#include "ggml-sycl.h"
+#include "presets.hpp"
+
+#define GGML_COMMON_DECL_SYCL
+#define GGML_COMMON_IMPL_SYCL
+#include "ggml-common.h"
+
+void* ggml_sycl_host_malloc(size_t size);
+void ggml_sycl_host_free(void* ptr);
+
+static int g_ggml_sycl_debug = 0;
+#define GGML_SYCL_DEBUG(...) \
+ do { \
+ if (g_ggml_sycl_debug) \
+ fprintf(stderr, __VA_ARGS__); \
+ } while (0)
+
+#define CHECK_TRY_ERROR(expr) \
+ [&]() { \
+ try { \
+ expr; \
+ return dpct::success; \
+ } catch (std::exception const& e) { \
+ std::cerr << e.what() << "\nException caught at file:" << __FILE__ \
+ << ", line:" << __LINE__ << ", func:" << __func__ \
+ << std::endl; \
+ return dpct::default_error; \
+ } \
+ }()
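+
+// Illustrative note (editorial sketch, not part of the original patch): CHECK_TRY_ERROR wraps
+// an expression in an immediately-invoked lambda so that SYCL exceptions are converted back
+// into dpct error codes, which SYCL_CHECK (defined below) can then test. Typical use, with
+// n_bytes and queue as placeholders:
+//
+//     void * ptr = nullptr;
+//     SYCL_CHECK(CHECK_TRY_ERROR(ptr = sycl::malloc_device(n_bytes, queue)));
+//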
+
+
+#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
+#define VER_4VEC 610 // TODO: for hardware optimization.
+#define VER_GEN9 700 // TODO: for hardware optimization.
+#define VER_GEN12 1000000 // TODO: for hardware optimization.
+#define VER_GEN13 (VER_GEN12 + 1030) // TODO: for hardware optimization.
+
+#define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to the hardware
+
+// define for XMX on Intel GPUs
+// TODO: XMX is not actually used here yet.
+#if !defined(GGML_SYCL_FORCE_MMQ)
+ #define SYCL_USE_XMX
+#endif
+
+// max batch size to use MMQ kernels when matrix (XMX) hardware is available
+#define MMQ_MAX_BATCH_SIZE 32
+
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244 4267) // possible loss of data
+#endif
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_SYCL_DMMV_X
+#define GGML_SYCL_DMMV_X 32
+#endif
+#ifndef GGML_SYCL_MMV_Y
+#define GGML_SYCL_MMV_Y 1
+#endif
+
+typedef sycl::queue *queue_ptr;
+
+enum ggml_sycl_backend_gpu_mode {
+ SYCL_UNSET_GPU_MODE = -1,
+ SYCL_SINGLE_GPU_MODE = 0,
+ SYCL_MUL_GPU_MODE
+};
+
+static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+// deliberately dereference a null pointer to abort immediately; called from the
+// exception handler in ggml_sycl_set_device() so a debugger/core dump points here
+static void crash() {
+ int* ptr = NULL;
+ *ptr = 0;
+}
+
+[[noreturn]] static void ggml_sycl_error(
+ const char* stmt,
+ const char* func,
+ const char* file,
+ const int line,
+ const char* msg) {
+ fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg);
+ fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
+ GGML_ASSERT(!"SYCL error");
+}
+
+#define SYCL_CHECK(err) \
+ do { \
+ auto err_ = (err); \
+ if (err_ != 0) \
+ ggml_sycl_error( \
+ #err, \
+ __func__, \
+ __FILE__, \
+ __LINE__, \
+ "Meet error in this line code!"); \
+ } while (0)
+
+#if DPCT_COMPAT_RT_VERSION >= 11100
+#define GGML_SYCL_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_SYCL_ASSUME(x)
+#endif // DPCT_COMPAT_RT_VERSION >= 11100
+
+#ifdef GGML_SYCL_F16
+typedef sycl::half dfloat; // dequantize float
+typedef sycl::half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef sycl::float2 dfloat2;
+#endif // GGML_SYCL_F16
+
+#define MMVQ_MAX_BATCH_SIZE 8
+
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+static int g_all_sycl_device_count = -1;
+static bool g_ggml_backend_sycl_buffer_type_initialized = false;
+
+static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode =
+ SYCL_UNSET_GPU_MODE;
+
+static void* g_scratch_buffer = nullptr;
+static size_t g_scratch_size = 0; // disabled by default
+static size_t g_scratch_offset = 0;
+
+[[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) {
+ stream_ct1 << "ERROR: ggml-sycl was compiled without support for the "
+ "current GPU architecture.\n";
+ // __trap();
+ std::exit(1);
+
+ (void)bad_arch; // suppress unused function warning
+}
+
+int get_current_device_id();
+
+inline dpct::err0 ggml_sycl_set_device(const int device) try {
+
+ int current_device_id;
+ SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
+
+ // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d,
+ // current_device_id=%d\n", device, current_device);
+ if (device == current_device_id) {
+ return 0;
+ }
+
+ return CHECK_TRY_ERROR(dpct::select_device(device));
+} catch (sycl::exception const& exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ crash();
+ std::exit(1);
+}
+
+//////////////////////
+
+struct ggml_sycl_device_info {
+ int device_count;
+
+ struct sycl_device_info {
+ int cc; // compute capability
+ // int nsm; // number of streaming multiprocessors
+ // size_t smpb; // max. shared memory per block
+ bool vmm; // virtual memory support
+ size_t total_vram;
+ };
+
+ sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
+
+ std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
+
+ int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
+};
+
+const ggml_sycl_device_info & ggml_sycl_info();
+
+struct ggml_sycl_pool {
+ virtual ~ggml_sycl_pool() = default;
+
+ virtual void * alloc(size_t size, size_t * actual_size) = 0;
+ virtual void free(void * ptr, size_t size) = 0;
+};
+
+template<typename T>
+struct ggml_sycl_pool_alloc {
+ ggml_sycl_pool * pool = nullptr;
+ T * ptr = nullptr;
+ size_t actual_size = 0;
+
+ explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) {
+ }
+
+ ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) {
+ alloc(size);
+ }
+
+ ~ggml_sycl_pool_alloc() {
+ if (ptr != nullptr) {
+ pool->free(ptr, actual_size);
+ }
+ }
+
+ // size is in number of elements
+ T * alloc(size_t size) {
+ GGML_ASSERT(pool != nullptr);
+ GGML_ASSERT(ptr == nullptr);
+ ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+ return ptr;
+ }
+
+ T * alloc(ggml_sycl_pool & pool, size_t size) {
+ this->pool = &pool;
+ return alloc(size);
+ }
+
+ T * get() {
+ return ptr;
+ }
+
+ ggml_sycl_pool_alloc() = default;
+ ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete;
+ ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete;
+ ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete;
+ ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete;
+};
+
+// backend interface
+
+struct ggml_tensor_extra_gpu {
+ void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split
+ // tensors
+ dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
+ [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
+};
+
+struct ggml_backend_sycl_context {
+ int device;
+ std::string name;
+
+ queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
+
+ explicit ggml_backend_sycl_context(int device) :
+ device(device),
+ name(GGML_SYCL_NAME + std::to_string(device)) {
+ }
+
+ queue_ptr stream(int device, int stream) {
+ if (qptrs[device][stream] == nullptr) {
+ qptrs[device][stream] = &(dpct::get_device(device).default_queue());
+ }
+ return qptrs[device][stream];
+ }
+
+ queue_ptr stream() {
+ return stream(device, 0);
+ }
+
+ // pool
+ std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+
+ static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
+
+ ggml_sycl_pool & pool(int device) {
+ if (pools[device] == nullptr) {
+ pools[device] = new_pool_for_device(stream(device,0), device);
+ }
+ return *pools[device];
+ }
+
+ ggml_sycl_pool & pool() {
+ return pool(device);
+ }
+};
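
A minimal caller sketch (illustrative only; example_op and n_floats are hypothetical names, not part of this patch) of how an op obtains the context's default queue and a pooled scratch buffer:

    static void example_op(ggml_backend_sycl_context & ctx, size_t n_floats) {
        queue_ptr q = ctx.stream();                              // stream 0 on ctx.device
        ggml_sycl_pool_alloc<float> tmp(ctx.pool(), n_floats);   // pooled allocation of n_floats elements
        float * scratch = tmp.get();
        // ... submit kernels on q that read/write scratch ...
        (void) q; (void) scratch;
    }   // tmp's destructor returns the buffer to the pool here
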
+
+// common device functions
+
+static __dpct_inline__ float warp_reduce_sum(float x,
+ const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ /*
+ DPCT1096:98: The right-most dimension of the work-group used in the SYCL
+ kernel that calls this function may be less than "32". The function
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
+ CPU device. Modify the size of the work-group to ensure that the value
+ of the right-most dimension is a multiple of "32".
+ */
+ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
+ }
+ return x;
+}
+
+static __dpct_inline__ sycl::float2
+warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
+ mask);
+ a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
+ mask);
+ }
+ return a;
+}
+
+static __dpct_inline__ float warp_reduce_max(float x,
+ const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ /*
+ DPCT1096:97: The right-most dimension of the work-group used in the SYCL
+ kernel that calls this function may be less than "32". The function
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
+ CPU device. Modify the size of the work-group to ensure that the value
+ of the right-most dimension is a multiple of "32".
+ */
+ x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
+ item_ct1.get_sub_group(), x, mask));
+ }
+ return x;
+}
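
Both reductions rely on the XOR butterfly: after log2(WARP_SIZE) exchanges every lane holds the combined value. A tiny host-side illustration of the same pattern (plain C++ sketch, not SYCL, and not part of this patch):

    #include <cstdio>

    int main() {
        const int N = 8;                          // stands in for WARP_SIZE
        float lane[N] = {1, 2, 3, 4, 5, 6, 7, 8};
        for (int mask = N / 2; mask > 0; mask >>= 1) {
            float next[N];
            for (int i = 0; i < N; ++i) {
                next[i] = lane[i] + lane[i ^ mask];  // analogue of permute_sub_group_by_xor
            }
            for (int i = 0; i < N; ++i) lane[i] = next[i];
        }
        printf("%g\n", lane[0]);                  // 36: every lane ends up with the full sum
        return 0;
    }
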
+
+// Helper for loading aligned data into a sycl::vec
+template <typename Tp, int n>
+inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
+ return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
+}
+
+// Helper for getting a raw pointer out of a local accessor without warnings
+template <typename Tp, int dim>
+static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
+ return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
+}
+
+#endif // GGML_SYCL_COMMON_HPP
diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp
new file mode 100644
index 00000000..632eedb9
--- /dev/null
+++ b/ggml/src/ggml-sycl/concat.cpp
@@ -0,0 +1,195 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "concat.hpp"
+#include "common.hpp"
+
+static void concat_f32_dim0(const float *x, const float *y, float *dst,
+ const int ne0, const int ne00,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (nidx < ne00) { // src0
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+ item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
+ item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static void concat_f32_dim1(const float *x, const float *y, float *dst,
+ const int ne0, const int ne01,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (item_ct1.get_group(1) < ne01) { // src0
+ int offset_src =
+ nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx + (item_ct1.get_group(1) - ne01) * ne0 +
+ item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static void concat_f32_dim2(const float *x, const float *y, float *dst,
+ const int ne0, const int ne02,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (item_ct1.get_group(0) < ne02) { // src0
+ int offset_src = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx + item_ct1.get_group(1) * ne0 +
+ (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static void concat_f32_sycl(const float *x, const float *y, float *dst,
+ int ne00, int ne01, int ne02, int ne0, int ne1,
+ int ne2, int dim, queue_ptr stream) {
+ int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
+ sycl::range<3> gridDim(ne2, ne1, num_blocks);
+ switch (dim) {
+ case 0:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
+ });
+ break;
+ case 1:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
+ });
+ break;
+ default:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
+ });
+ break;
+ }
+}
+
+// non-contiguous kernel (slow)
+static void concat_f32_sycl_non_cont(
+ queue_ptr stream, const char *src0, const char *src1, char *dst,
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
+ uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
+ int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
+ uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
+ int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
+ uint64_t nb3, int32_t dim) {
+ sycl::range<3> gridDim(ne3, ne2, ne1);
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
+ [=](sycl::nd_item<3> item_ct1) {
+ int64_t i3 = item_ct1.get_group(0);
+ int64_t i2 = item_ct1.get_group(1);
+ int64_t i1 = item_ct1.get_group(2);
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+ const float *x;
+
+ for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
+ i0 += item_ct1.get_local_range(2)) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
+ (i0)*nb00);
+ } else {
+ x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
+ (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
+ }
+
+ float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
+
+ *y = *x;
+ }
+ });
+}
+
+void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst) {
+ queue_ptr stream = ctx.stream();
+
+ const int32_t dim = ((int32_t *)dst->op_params)[0];
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ const float *src0_d = (const float *)src0->data;
+ const float *src1_d = (const float *)src1->data;
+
+ float *dst_d = (float *)dst->data;
+
+ if (dim != 3) {
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_sycl(
+ src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1],
+ src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
+ }
+ } else
+ concat_f32_sycl_non_cont(
+ stream, (const char *)src0->data, (const char *)src1->data,
+ (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
+ src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
+ dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
+}
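
As a cross-check on the index arithmetic, a host-side reference of the dim == 0 case handled by concat_f32_dim0 (illustrative sketch; concat_dim0_ref and nrows are our names, treating the (group(0), group(1)) grid as flattened rows):

    static void concat_dim0_ref(const float * x, const float * y, float * dst,
                                int ne00, int ne10, int nrows) {
        const int ne0 = ne00 + ne10;                 // dst row length
        for (int r = 0; r < nrows; ++r) {
            for (int i = 0; i < ne0; ++i) {
                dst[r*ne0 + i] = i < ne00 ? x[r*ne00 + i]           // first ne00 values from src0
                                          : y[r*ne10 + (i - ne00)]; // remainder from src1
            }
        }
    }
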
diff --git a/ggml/src/ggml-sycl/concat.hpp b/ggml/src/ggml-sycl/concat.hpp
new file mode 100644
index 00000000..5a04feaa
--- /dev/null
+++ b/ggml/src/ggml-sycl/concat.hpp
@@ -0,0 +1,21 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_CONCAT_HPP
+#define GGML_SYCL_CONCAT_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst);
+
+#endif // GGML_SYCL_CONCAT_HPP
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
new file mode 100644
index 00000000..39c28753
--- /dev/null
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -0,0 +1,547 @@
+#include "convert.hpp"
+#include "dequantize.hpp"
+#include "presets.hpp"
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2));
+
+ if (i >= k) {
+ return;
+ }
+
+ const int ib = i/qk; // block index
+ const int iqs = (i%qk)/qr; // quant index
+ const int iybs = i - i%qk; // y block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs, v);
+
+ y[iybs + iqs + 0] = v.x();
+ y[iybs + iqs + y_offset] = v.y();
+}
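
A worked index example for the block dequantizer above (annotation, assuming Q4_0 parameters qk = 32, qr = 2):

    // i = 10  ->  ib   = 10 / 32       = 0   (block index)
    //             iqs  = (10 % 32) / 2 = 5   (byte index within the block)
    //             iybs = 10 - 10 % 32  = 0   (output block start), y_offset = qk/2 = 16
    // so this thread writes y[5] from the low nibble of qs[5] and y[21] from the high
    // nibble, matching the Q4_0 layout where a byte's two nibbles land 16 positions apart.
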
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_sycl(const void *__restrict__ vx,
+ dst_t *__restrict__ y, const int k,
+ dpct::queue_ptr stream) {
+ const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+ stream->parallel_for(
+ sycl::nd_range<3>(
+ sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+#if QK_K == 256
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 64),
+ sycl::range<3>(1, 1, 64)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q2_K(vx, y, item_ct1);
+ });
+ }
+#else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q2_K(vx, y, item_ct1);
+ });
+ }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+#if QK_K == 256
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 64),
+ sycl::range<3>(1, 1, 64)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q3_K(vx, y, item_ct1);
+ });
+ }
+#else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q3_K(vx, y, item_ct1);
+ });
+ }
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q4_0(vx, y, nb32, item_ct1);
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q4_1(vx, y, nb32, item_ct1);
+ });
+ }
+}
+
+
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+#if QK_K == 256
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 64),
+ sycl::range<3>(1, 1, 64)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q5_K(vx, y, item_ct1);
+ });
+ }
+#else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q5_K(vx, y, item_ct1);
+ });
+ }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+#if QK_K == 256
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 64),
+ sycl::range<3>(1, 1, 64)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q6_K(vx, y, item_ct1);
+ });
+ }
+#else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_q6_K(vx, y, item_ct1);
+ });
+ }
+
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq1_s(
+ vx, y, item_ct1, iq1s_grid_gpu
+ );
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq1_m(
+ vx, y, item_ct1, iq1s_grid_gpu
+ );
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq2_xxs(
+ vx, y, item_ct1, iq2xxs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq2_xs(
+ vx, y, item_ct1, iq2xs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq2_s(vx, y, item_ct1);
+ });
+ });
+ }
+}
+
+
+template <typename dst_t>
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq3_xxs(
+ vx, y, item_ct1, iq3xxs_grid,
+ ksigns_iq2xs, kmask_iq2xs);
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = k / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq3_s(
+ vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
+ });
+ });
+ }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+#if QK_K == 64
+ dequantize_row_iq4_nl_sycl(vx, y, k, stream);
+#else
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq4_xs(vx, y, item_ct1);
+ });
+ });
+ }
+#endif
+}
+
+template <typename dst_t>
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+ dpct::queue_ptr stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+ sycl::range<3>(1, 1, 32),
+ sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_iq4_nl(vx, y, item_ct1);
+ });
+ });
+ }
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i >= k) {
+ return;
+ }
+
+ const src_t * x = (src_t *) vx;
+
+ y[i] = x[i];
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_sycl(const void *__restrict__ vx,
+ dst_t *__restrict__ y, const int k,
+ dpct::queue_ptr stream) {
+ const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(
+ sycl::range<3>(1, 1, num_blocks) *
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ convert_unary<src_t>(vx, y, k, item_ct1);
+ });
+ }
+}
+
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_sycl;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_sycl;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_sycl;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_sycl;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_sycl;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_sycl;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_sycl;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_sycl;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_sycl;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_sycl;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_sycl;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_sycl;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_sycl;
+ case GGML_TYPE_F32:
+ return convert_unary_sycl<float>;
+ default:
+ return nullptr;
+ }
+}
+
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_sycl;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_sycl;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_sycl;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_sycl;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_sycl;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_sycl;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_sycl;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_sycl;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_sycl;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_sycl;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_sycl;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_sycl;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_sycl;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_sycl;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_sycl;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_sycl;
+ case GGML_TYPE_F16:
+ return convert_unary_sycl<sycl::half>;
+ default:
+ return nullptr;
+ }
+}
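
A hypothetical caller sketch (dequantize_to_fp32 is our name, not part of this patch) showing how the lookup above is typically consumed:

    static bool dequantize_to_fp32(ggml_type type, const void * src, float * dst,
                                   int n_elements, dpct::queue_ptr stream) {
        const to_fp32_sycl_t to_fp32 = ggml_get_to_fp32_sycl(type);
        if (to_fp32 == nullptr) {
            return false;          // no SYCL dequantizer registered for this type
        }
        to_fp32(src, dst, n_elements, stream);
        stream->wait();            // the kernels above are submitted asynchronously
        return true;
    }
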
diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp
new file mode 100644
index 00000000..b1f10d63
--- /dev/null
+++ b/ggml/src/ggml-sycl/convert.hpp
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_CONVERT_HPP
+#define GGML_SYCL_CONVERT_HPP
+
+#include "common.hpp"
+
+template <typename T>
+using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
+ int k, dpct::queue_ptr stream);
+typedef to_t_sycl_t<float> to_fp32_sycl_t;
+typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
+
+to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type);
+to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type);
+
+#endif // GGML_SYCL_CONVERT_HPP
diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp
new file mode 100644
index 00000000..ed8ad098
--- /dev/null
+++ b/ggml/src/ggml-sycl/dequantize.hpp
@@ -0,0 +1,698 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_DEQUANTIZE_HPP
+#define GGML_SYCL_DEQUANTIZE_HPP
+
+#include "common.hpp"
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+
+static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
+ const int iqs, dfloat2 &v) {
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x() = vui & 0xF;
+ v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+ // v = v - {8.0f, 8.0f};
+ // v = v * {d, d};
+ v.s0() = (v.s0() - 8.0f) * d;
+ v.s1() = (v.s1() - 8.0f) * d;
+
+#else
+ v.x() = (v.x() - 8.0f) * d;
+ v.y() = (v.y() - 8.0f) * d;
+#endif // GGML_SYCL_F16
+}
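
A worked value example for the Q4_0 path above (annotation):

    // d = 0.5f, qs[iqs] = 0x3A:
    //   low  nibble 0xA = 10  ->  (10 - 8) * 0.5f =  1.0f   (written to v.x()/v.s0())
    //   high nibble 0x3 =  3  ->  ( 3 - 8) * 0.5f = -2.5f   (written to v.y()/v.s1())
    // i.e. Q4_0 stores unsigned 4-bit codes with an implicit zero-point of 8.
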
+
+static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
+ const int iqs, dfloat2 &v) {
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const dfloat d = x[ib].dm[0];
+ const dfloat m = x[ib].dm[1];
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x() = vui & 0xF;
+ v.y() = vui >> 4;
+
+#ifdef GGML_SYCL_F16
+ // v = v * {d, d};
+ // v = v + {m, m};
+ v.s0() = (v.s0() * d) + m;
+ v.s1() = (v.s1() * d) + m;
+
+#else
+ v.x() = (v.x() * d) + m;
+ v.y() = (v.y() * d) + m;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
+ const int iqs, dfloat2 &v) {
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_SYCL_F16
+ // v = v - {16.0f, 16.0f};
+ // v = v * {d, d};
+ v.s0() = (v.s0() - 16.0f) * d;
+ v.s1() = (v.s1() - 16.0f) * d;
+
+#else
+ v.x() = (v.x() - 16.0f) * d;
+ v.y() = (v.y() - 16.0f) * d;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
+ const int iqs, dfloat2 &v) {
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const dfloat d = x[ib].dm[0];
+ const dfloat m = x[ib].dm[1];
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+#ifdef GGML_SYCL_F16
+ // v = v * {d, d};
+ // v = v + {m, m};
+ v.s0() = (v.s0() * d) + m;
+ v.s1() = (v.s1() * d) + m;
+#else
+ v.x() = (v.x() * d) + m;
+ v.y() = (v.y() * d) + m;
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
+ const int iqs, dfloat2 &v) {
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const dfloat d = x[ib].d;
+
+ v.x() = x[ib].qs[iqs + 0];
+ v.y() = x[ib].qs[iqs + 1];
+
+#ifdef GGML_SYCL_F16
+ // v = v * {d, d};
+ v.s0() *= d;
+ v.s1() *= d;
+#else
+ v.x() *= d;
+ v.y() *= d;
+#endif // GGML_SYCL_F16
+}
+
+template<typename dst_t>
+static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+
+ // assume 32 threads
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8;
+ const int ir = tid%8;
+ const int ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+ const float d = sycl::vec<sycl::half, 1>(x->d)
+ .convert<float, sycl::rounding_mode::automatic>()[0];
+ const float dm = -8*d;
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l+ 0] = d * (q[l] & 0xF) + dm;
+ y[l+16] = d * (q[l] >> 4) + dm;
+ }
+}
+
+template<typename dst_t>
+static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+
+ // assume 32 threads
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8;
+ const int ir = tid%8;
+ const int ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+ const sycl::float2 d =
+ x->dm.convert<float, sycl::rounding_mode::automatic>();
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l + 0] = d.x() * (q[l] & 0xF) + d.y();
+ y[l + 16] = d.x() * (q[l] >> 4) + d.y();
+ }
+}
+
+
+//================================== k-quants
+
+template<typename dst_t>
+static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_q2_K * x = (const block_q2_K *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int n = tid/32;
+ const int l = tid - 32*n;
+ const int is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ float dall = x[i].dm[0];
+ float dmin = x[i].dm[1];
+ y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+ y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+ y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ dst_t * y = yy + i*QK_K + 16*is + il;
+
+ float dall = x[i].dm[0];
+ float dmin = x[i].dm[1];
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_q3_K * x = (const block_q3_K *) vx;
+
+#if QK_K == 256
+ const int r = item_ct1.get_local_id(2) / 4;
+ const int tid = r/2;
+ const int is0 = r%2;
+ const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
+ const int n = tid / 4;
+ const int j = tid - 4*n;
+
+ uint8_t m = 1 << (4*n + j);
+ int is = 8*n + 2*j + is0;
+ int shift = 2*j;
+
+ int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+ (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+ float d_all = x[i].d;
+ float dl = d_all * (us - 32);
+
+ dst_t * y = yy + i*QK_K + 128*n + 32*j;
+ const uint8_t * q = x[i].qs + 32*n;
+ const uint8_t * hm = x[i].hmask;
+
+ for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+ const int tid = item_ct1.get_local_id(2);
+ const int is = tid/16; // 0 or 1
+ const int il = tid%16; // 0...15
+ const int im = il/8; // 0...1
+ const int in = il%8; // 0...7
+
+ dst_t * y = yy + i*QK_K + 16*is + il;
+
+ const uint8_t q = x[i].qs[il] >> (2*is);
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
+ const float d = (float)x[i].d;
+
+ if (is == 0) {
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ } else {
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+ }
+#endif
+
+}
+
+#if QK_K == 256
+static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+ if (j < 4) {
+ d = q[j] & 63;
+ m = q[j + 4] & 63;
+ } else {
+ d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+}
+#endif
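
For orientation (annotation, not part of the patch): Q4_K/Q5_K pack 8 six-bit scales and 8 six-bit mins into the 12-byte scales array that get_scale_min_k4 unpacks. For j < 4 the low 6 bits of q[j] / q[j+4] hold scale/min directly; for j >= 4 the low/high nibble of q[j+4] supplies bits 0-3 and the top two bits of q[j-4] (scale) or q[j] (min) supply bits 4-5:

    // 8 scales * 6 bits + 8 mins * 6 bits = 96 bits = 12 bytes
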
+
+template<typename dst_t>
+static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
+ const block_q4_K * x = (const block_q4_K *) vx;
+
+ const int i = item_ct1.get_group(2);
+
+#if QK_K == 256
+ // assume 32 threads
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8;
+ const int ir = tid%8;
+ const int is = 2*il;
+ const int n = 4;
+
+ dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+ const sycl::half2 dm = x[i].dm;
+ const float dall = dm[0];
+ const float dmin = dm[1];
+
+ if (tid < 12)
+ scales_local[tid] = x[i].scales[tid];
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, scales_local, sc, m);
+ const float d1 = dall * sc;
+ const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, scales_local, sc, m);
+ const float d2 = dall * sc;
+ const float m2 = dmin * m;
+
+ sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
+ for (int l = 0; l < n; ++l) {
+ y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
+ y[l +32] = d2 * (q_vec[l] >> 4) - m2;
+ }
+#else
+ const int tid = item_ct1.get_local_id(2);
+ const uint8_t * q = x[i].qs;
+ dst_t * y = yy + i*QK_K;
+ const float d = (float)x[i].dm[0];
+ const float m = (float)x[i].dm[1];
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+ const block_q5_K * x = (const block_q5_K *) vx;
+
+ const int i = item_ct1.get_group(2);
+
+#if QK_K == 256
+ // assume 64 threads - this is very slightly better than the one below
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/16; // il is in 0...3
+ const int ir = tid%16; // ir is in 0...15
+ const int is = 2*il; // is is in 0...6
+
+ dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+ const float dall = x[i].dm[0];
+ const float dmin = x[i].dm[1];
+
+ const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+ const uint8_t * qh = x[i].qh + 2*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+
+ uint8_t hm = 1 << (2*il);
+ y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+ y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+ hm <<= 1;
+ y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+ y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+ const int tid = item_ct1.get_local_id(2);
+ const uint8_t q = x[i].qs[tid];
+ const int im = tid/8; // 0...3
+ const int in = tid%8; // 0...7
+ const int is = tid/16; // 0 or 1
+ const uint8_t h = x[i].qh[in] >> im;
+ const float d = x[i].d;
+ dst_t * y = yy + i*QK_K + tid;
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+ const block_q6_K * x = (const block_q6_K *) vx;
+
+ const int i = item_ct1.get_group(2);
+#if QK_K == 256
+
+ // assume 64 threads - this is very slightly better than the one below
+ const int tid = item_ct1.get_local_id(2);
+ const int ip = tid/32; // ip is 0 or 1
+    const int il = tid - 32*ip; // 0...31
+ const int is = 8*ip + il/16;
+
+ dst_t * y = yy + i*QK_K + 128*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t * ql = x[i].ql + 64*ip + il;
+ const uint8_t qh = x[i].qh[32*ip + il];
+ const int8_t * sc = x[i].scales + is;
+
+ y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+ y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+ y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+ // assume 32 threads
+ const int tid = item_ct1.get_local_id(2);
+ const int ip = tid/16; // 0 or 1
+ const int il = tid - 16*ip; // 0...15
+
+ dst_t * y = yy + i*QK_K + 16*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t ql = x[i].ql[16*ip + il];
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
+ const int8_t * sc = x[i].scales;
+
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
+}
+
+template<typename dst_t>
+static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint64_t *iq2xxs_grid_ptr,
+ const uint8_t *ksigns_iq2xs_ptr,
+ const uint8_t *kmask_iq2xs_ptr) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
+#else
+ assert(false);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint64_t *iq2xs_grid,
+ const uint8_t *ksigns_iq2xs,
+ const uint8_t *kmask_iq2xs) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+ assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+#pragma unroll
+ for (int j = 0; j < 8; ++j)
+ y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+ assert(false);
+
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint32_t *iq3xxs_grid,
+ const uint8_t *ksigns_iq2xs,
+ const uint8_t *kmask_iq2xs) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * q3 = x[i].qs + 8*ib;
+ const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+#else
+ assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq3_s * x = (const block_iq3_s *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * qs = x[i].qs + 8*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+ const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+ const uint8_t signs = x[i].signs[4*ib + il];
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+#else
+ assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint32_t *iq1s_grid_gpu) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq1_s * x = (const block_iq1_s *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+#else
+ assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1,
+ const uint32_t *iq1s_grid_gpu) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq1_m * x = (const block_iq1_m *) vx;
+
+ const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+#pragma unroll
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+#else
+ assert(false);
+#endif
+
+}
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_group(2);
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = (float)x[ib].d;
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+
+}
+
+
+template <typename dst_t>
+__dpct_inline__ static void
+dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i = item_ct1.get_group(2);
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+ const int tid = item_ct1.get_local_id(2);
+ const int il = tid/8; // 0...3
+ const int ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+}
+
+
+#endif // GGML_SYCL_DEQUANTIZE_HPP
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
new file mode 100644
index 00000000..70a94fc1
--- /dev/null
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -0,0 +1,1023 @@
+#include "convert.hpp"
+#include "dmmv.hpp"
+#include "dequantize.hpp"
+#include "presets.hpp"
+
+
+static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ const sycl::half *x = (const sycl::half *)vx;
+
+ // automatic half -> float type cast if dfloat == float
+ v.x() = x[ib + iqs + 0];
+ v.y() = x[ib + iqs + 1];
+}
+
+static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+ const float * x = (const float *) vx;
+
+    // automatic float -> half type cast if dfloat == sycl::half
+ v.x() = x[ib + iqs + 0];
+ v.y() = x[ib + iqs + 1];
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ // qk = quantized weights per x block
+ // qr = number of quantized weights per data value in x block
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int tid = item_ct1.get_local_id(2);
+
+ const int iter_stride = 2*GGML_SYCL_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // number of quantized values per thread per i iteration
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_SYCL_F16
+ sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+ float tmp = 0.0f;
+#endif // GGML_SYCL_F16
+
+ for (int i = 0; i < ncols; i += iter_stride) {
+ const int col = i + vals_per_iter*tid;
+ const int ib = (row*ncols + col)/qk; // x block index
+ const int iqs = (col%qk)/qr; // x quant index
+ const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+ for (int j = 0; j < vals_per_iter; j += 2) {
+ // process 2 vals per j iter
+
+ // dequantize
+ // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+ // matrix multiplication
+ // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_SYCL_F16
+ dfloat2 t1{y[iybs + iqs + j / qr + 0],
+ y[iybs + iqs + j / qr + y_offset]};
+
+ tmp += v * t1;
+#else
+ tmp += v.x() * y[iybs + iqs + j / qr + 0];
+ tmp += v.y() * y[iybs + iqs + j / qr + y_offset];
+#endif // GGML_SYCL_F16
+ }
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (tid == 0) {
+#ifdef GGML_SYCL_F16
+ dst[row] = tmp.x() + tmp.y();
+#else
+ dst[row] = tmp;
+#endif // GGML_SYCL_F16
+ }
+}
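
For reference, the scalar result each row of the kernel above accumulates (illustrative sketch; dmmv_row_ref is our name and operates on already-dequantized data):

    static float dmmv_row_ref(int ncols, const float * x_row_dequant, const float * y) {
        float acc = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            acc += x_row_dequant[col] * y[col];      // dot product of one weight row with y
        }
        return acc;                                  // the kernel writes this to dst[row]
    }
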
+
+static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
+ nrows, item_ct1);
+ });
+ }
+}
+
+/*
+DPCT1110:4: The total declared local variable size in device function
+dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
+ const float *__restrict__ yy,
+ float *__restrict__ dst,
+ const int ncols, int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+    if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+ const int tid =
+ item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix =
+ item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int step = 16/K_QUANTS_PER_ITERATION;
+
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
+
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int s_offset = 8*im;
+ const int y_offset = 128*im + l0;
+
+ uint32_t aux[4];
+ const uint8_t * d = (const uint8_t *)aux;
+ const uint8_t * m = (const uint8_t *)(aux + 2);
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * q = x[i].qs + q_offset;
+
+ const float dall = x[i].dm[0];
+ const float dmin = x[i].dm[1];
+
+ const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+ aux[0] = a[0] & 0x0f0f0f0f;
+ aux[1] = a[1] & 0x0f0f0f0f;
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+ }
+ tmp += dall * sum1 - dmin * sum2;
+
+ }
+#else
+ const int tid = item_ct1.get_local_id(2) /
+ (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = item_ct1.get_local_id(2) %
+ (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION;
+
+ uint32_t uaux[2];
+ const uint8_t * d = (const uint8_t *)uaux;
+
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint32_t * s = (const uint32_t *)x[i].scales;
+
+ uaux[0] = s[0] & 0x0f0f0f0f;
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+ const sycl::float2 dall =
+ x[i].dm.convert<float, sycl::rounding_mode::automatic>();
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t ql = q[l];
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+ }
+ tmp += dall.x() * sum1 - dall.y() * sum2;
+ }
+
+#endif
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
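+
+// A minimal sketch of the sub-group reduction used in these kernels (assuming
+// QK_WARP_SIZE == 32): each xor mask halves the distance between exchanged lanes,
+// so after masks 16, 8, 4, 2, 1 every lane holds the full sub-group sum.
+//
+//   float v = partial;                        // one partial sum per lane
+//   for (int mask = 32 / 2; mask > 0; mask >>= 1) {
+//       v += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), v, mask);
+//   }
+//   // all lanes now hold the same total; lane 0 writes it to dst[row]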
+
+/*
+DPCT1110:5: The total declared local variable size in device function
+dequantize_mul_mat_vec_q3_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
+ const float *__restrict__ yy,
+ float *__restrict__ dst,
+ const int ncols, int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;
+
+ const int tid =
+ item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix =
+ item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+ const int step = 16/K_QUANTS_PER_ITERATION;
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0....15 or 0...7
+
+ const uint8_t m = 1 << (4*im);
+
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int y_offset = 128*im + l0;
+
+ uint16_t utmp[4];
+ const int8_t * s = (const int8_t *)utmp;
+
+ const uint16_t s_shift = 4*im;
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * q = x[i].qs + q_offset;
+ const uint8_t * h = x[i].hmask + l0;
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+ const float d = x[i].d;
+
+ float sum = 0;
+ for (int l = 0; l < n; ++l) {
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+ }
+ tmp += d * sum;
+
+ }
+#else
+
+ const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+ const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
+ const int in = offset/8; // 0 or 1
+ const int im = offset%8; // 0...7
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + offset;
+ const uint8_t * q = x[i].qs + offset;
+ const uint8_t * s = x[i].scales;
+
+ const float dall = (float)x[i].d;
+
+ float sum = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ const uint8_t hl = x[i].hmask[im+l] >> in;
+ const uint8_t ql = q[l];
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+ }
+ tmp += sum;
+ }
+#endif
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+/*
+DPCT1110:6: The total declared local variable size in device function
+dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
+ const float *__restrict__ yy,
+ float *__restrict__ dst,
+ const int ncols, int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ if (row >= nrows) return;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid =
+ item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix =
+ item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+ const int il = tid/step; // 0...3
+ const int ir = tid - step*il; // 0...7 or 0...3
+ const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+ uint32_t q32[4];
+ const uint8_t * q4 = (const uint8_t *)q32;
+#else
+ uint16_t q16[4];
+ const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y1 = yy + i*QK_K + y_offset;
+ const float * y2 = y1 + 128;
+
+ const float dall = x[i].dm[0];
+ const float dmin = x[i].dm[1];
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+ const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+ const uint32_t * q2 = q1 + 16;
+
+ q32[0] = q1[0] & 0x0f0f0f0f;
+ q32[1] = q1[0] & 0xf0f0f0f0;
+ q32[2] = q2[0] & 0x0f0f0f0f;
+ q32[3] = q2[0] & 0xf0f0f0f0;
+
+ sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < 4; ++l) {
+ s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4];
+ s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12];
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
+ s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
+ dmin * smin;
+#else
+ const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+ const uint16_t * q2 = q1 + 32;
+
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[0] & 0xf0f0;
+ q16[2] = q2[0] & 0x0f0f;
+ q16[3] = q2[0] & 0xf0f0;
+
+ sycl::float4 s = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ for (int l = 0; l < 2; ++l) {
+ s.x() += y1[l] * q4[l+0]; s.y() += y1[l+32] * q4[l+2];
+ s.z() += y2[l] * q4[l+4]; s.w() += y2[l+32] * q4[l+6];
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f/16.f + s.z() * sc[4] + s.w() * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+ }
+#else
+ const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ float tmp = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const float * y = yy + i*QK_K + step;
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+ const float d = (float)x[i].dm[0];
+ const float m = (float)x[i].dm[1];
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+ }
+ tmp += sum;
+ }
+
+#endif
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (tid == 0) {
+ dst[row] = tmp;
+ }
+}
+
+/*
+DPCT1110:7: The total declared local variable size in device function
+dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register
+pressure. Consult with your hardware vendor to find the total register size
+available and adjust the code, or use smaller sub-group size to avoid high
+register pressure.
+*/
+static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
+ const float *__restrict__ yy,
+ float *__restrict__ dst,
+ const int ncols,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int row = item_ct1.get_group(2);
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int tid = item_ct1.get_local_id(2) / 2; // 0...15
+ const int ix = item_ct1.get_local_id(2) % 2;
+
+ const int il = tid/4; // 0...3
+ const int ir = tid - 4*il;// 0...3
+ const int n = 2;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ const uint8_t hm1 = 1 << (2*im);
+ const uint8_t hm2 = hm1 << 4;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+ uint16_t q16[8];
+ const uint8_t * q4 = (const uint8_t *)q16;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+ const uint8_t * ql1 = x[i].qs + q_offset;
+ const uint8_t * qh = x[i].qh + l0;
+ const float * y1 = yy + i*QK_K + y_offset;
+ const float * y2 = y1 + 128;
+
+ const float dall = x[i].dm[0];
+ const float dmin = x[i].dm[1];
+
+ const uint16_t * a = (const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+ sycl::float4 sum = {0.f, 0.f, 0.f, 0.f};
+ float smin = 0;
+ const uint16_t * q1 = (const uint16_t *)ql1;
+ const uint16_t * q2 = q1 + 32;
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[8] & 0x0f0f;
+ q16[2] = (q1[0] >> 4) & 0x0f0f;
+ q16[3] = (q1[8] >> 4) & 0x0f0f;
+ q16[4] = q2[0] & 0x0f0f;
+ q16[5] = q2[8] & 0x0f0f;
+ q16[6] = (q2[0] >> 4) & 0x0f0f;
+ q16[7] = (q2[8] >> 4) & 0x0f0f;
+ for (int l = 0; l < n; ++l) {
+ sum.x() +=
+ y1[l + 0] * (q4[l + 0] + (qh[l + 0] & (hm1 << 0) ? 16 : 0)) +
+ y1[l + 16] * (q4[l + 2] + (qh[l + 16] & (hm1 << 0) ? 16 : 0));
+ sum.y() +=
+ y1[l + 32] * (q4[l + 4] + (qh[l + 0] & (hm1 << 1) ? 16 : 0)) +
+ y1[l + 48] * (q4[l + 6] + (qh[l + 16] & (hm1 << 1) ? 16 : 0));
+ sum.z() +=
+ y2[l + 0] * (q4[l + 8] + (qh[l + 0] & (hm2 << 0) ? 16 : 0)) +
+ y2[l + 16] * (q4[l + 10] + (qh[l + 16] & (hm2 << 0) ? 16 : 0));
+ sum.w() +=
+ y2[l + 32] * (q4[l + 12] + (qh[l + 0] & (hm2 << 1) ? 16 : 0)) +
+ y2[l + 48] * (q4[l + 14] + (qh[l + 16] & (hm2 << 1) ? 16 : 0));
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+ }
+ tmp += dall * (sum.x() * sc[0] + sum.y() * sc[1] + sum.z() * sc[4] +
+ sum.w() * sc[5]) -
+ dmin * smin;
+ }
+
+#else
+ const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15
+ const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+ const int step = tid * K_QUANTS_PER_ITERATION;
+ const int im = step/8;
+ const int in = step%8;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+ const uint8_t * q = x[i].qs + step;
+ const int8_t * s = x[i].scales;
+ const float * y = yy + i*QK_K + step;
+ const float d = x[i].d;
+ float sum = 0.f;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ const uint8_t h = x[i].qh[in+j] >> im;
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+ }
+ tmp += sum;
+ }
+#endif
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ if (row >= nrows) return;
+
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+#if QK_K == 256
+
+ const int tid =
+ item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix =
+ item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
+ const int is = 0;
+#else
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
+ const int is = in / 4;
+#endif
+ const int ql_offset = 64*im + l0;
+ const int qh_offset = 32*im + l0;
+ const int s_offset = 8*im + is;
+ const int y_offset = 128*im + l0;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + y_offset;
+ const uint8_t * ql = x[i].ql + ql_offset;
+ const uint8_t * qh = x[i].qh + qh_offset;
+ const int8_t * s = x[i].scales + s_offset;
+
+ const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+ + y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+ tmp += sum;
+#else
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+ }
+ tmp += sum;
+#endif
+
+ }
+
+#else
+
+ const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7
+ const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+ const int step = tid * K_QUANTS_PER_ITERATION;
+
+ float tmp = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+ const float * y = yy + i * QK_K + step;
+ const uint8_t * ql = x[i].ql + step;
+ const uint8_t * qh = x[i].qh + step;
+ const int8_t * s = x[i].scales;
+
+ const float d = x[i+0].d;
+
+ float sum = 0;
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+ }
+ tmp += sum;
+
+ }
+
+#endif
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (tid == 0) {
+ dst[row] = tmp;
+ }
+}
+
+
+static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
+ vx, y, dst, ncols, nrows, item_ct1);
+ });
+ }
+}
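+
+// For illustration only: assuming the common configuration GGML_SYCL_MMV_Y == 1
+// and WARP_SIZE == 32, a call with nrows == 4096 resolves to
+//
+//   block_num_y = (4096 + 1 - 1) / 1 = 4096;      // one work-group per row
+//   block_nums  = sycl::range<3>(1, 1, 4096);
+//   block_dims  = sycl::range<3>(1, 1, 32);       // one 32-lane sub-group
+//   // global range = block_nums * block_dims = (1, 1, 131072)
+//
+// The grid is laid out along dimension 2 because, as noted above, the row count
+// may exceed the maximum grid size of the other dimensions.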
+
+static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
+ vx, y, dst, ncols, nrows, item_ct1);
+ });
+ }
+}
+
+static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
+ vx, y, dst, ncols, nrows, item_ct1);
+ });
+ }
+}
+
+static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
+ vx, y, dst, ncols, nrows, item_ct1);
+ });
+ }
+}
+
+static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
+ vx, y, dst, ncols, nrows, item_ct1);
+ });
+ }
+}
+
+static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
+ });
+}
+
+static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
+ });
+}
+
+static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
+ });
+}
+
+static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
+ });
+}
+
+static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
+ const int block_num_y = (nrows + ny - 1) / ny;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
+ });
+}
+
+void ggml_sycl_op_dequantize_mul_mat_vec(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const dpct::queue_ptr &stream) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_SYCL_F16
+ ggml_sycl_pool_alloc<sycl::half> src1_dfloat_a(ctx.pool());
+ sycl::half *src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 =
+ src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = src1_dfloat_a.alloc(ne00);
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ GGML_ASSERT(to_fp16_sycl != nullptr);
+ to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream);
+ }
+#else
+ const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_SYCL_F16
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_F16:
+ convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+ break;
+ default:
+ printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
+ GGML_ASSERT(false);
+ break;
+ }
+
+ (void) src1;
+ (void) dst;
+ (void) src1_ddq_i;
+ (void) src1_ncols;
+ (void) src1_padded_row_size;
+}
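+
+// Rough sketch of the dispatch above (GGML_SYCL_F16 being enabled is an assumption
+// here): the non-K quants and F16 consume the possibly half-converted activations,
+// while the K-quant kernels always read the original f32 input.
+//
+//   // GGML_TYPE_Q4_0 ... GGML_TYPE_Q8_0, GGML_TYPE_F16  ->  src1_dfloat (half if converted)
+//   // GGML_TYPE_Q2_K ... GGML_TYPE_Q6_K                 ->  src1_ddf_i  (always float)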
diff --git a/ggml/src/ggml-sycl/dmmv.hpp b/ggml/src/ggml-sycl/dmmv.hpp
new file mode 100644
index 00000000..bd837356
--- /dev/null
+++ b/ggml/src/ggml-sycl/dmmv.hpp
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_DMMV_HPP
+#define GGML_SYCL_DMMV_HPP
+
+#include "common.hpp"
+
+
+void ggml_sycl_op_dequantize_mul_mat_vec(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const dpct::queue_ptr &stream);
+
+#endif // GGML_SYCL_DMMV_HPP
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
new file mode 100644
index 00000000..4aaa76bf
--- /dev/null
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -0,0 +1,3011 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_DPCT_HELPER_HPP
+#define GGML_SYCL_DPCT_HELPER_HPP
+
+#include <sycl/sycl.hpp>
+#include <sycl/half_type.hpp>
+#include <oneapi/mkl.hpp>
+#include <map>
+
+#include "ggml.h"
+
+#if defined(__linux__)
+#include <sys/mman.h>
+#elif defined(_WIN64)
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#else
+#error "Only support Windows and Linux."
+#endif
+
+#if defined(__linux__)
+#include <unistd.h>
+#include <sys/syscall.h>
+#endif
+#if defined(_WIN64)
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+#define DPCT_COMPATIBILITY_TEMP (900)
+
+#if defined(_MSC_VER)
+#define __dpct_align__(n) __declspec(align(n))
+#define __dpct_inline__ __forceinline
+#else
+#define __dpct_align__(n) __attribute__((aligned(n)))
+#define __dpct_inline__ __inline__ __attribute__((always_inline))
+#endif
+
+#if defined(_MSC_VER)
+#define __dpct_noinline__ __declspec(noinline)
+#else
+#define __dpct_noinline__ __attribute__((noinline))
+#endif
+
+inline std::string get_device_type_name(const sycl::device &Device) {
+ auto DeviceType = Device.get_info<sycl::info::device::device_type>();
+ switch (DeviceType) {
+ case sycl::info::device_type::cpu:
+ return "cpu";
+ case sycl::info::device_type::gpu:
+ return "gpu";
+ case sycl::info::device_type::host:
+ return "host";
+ case sycl::info::device_type::accelerator:
+ return "acc";
+ default:
+ return "unknown";
+ }
+}
+
+inline std::string get_device_backend_and_type(const sycl::device &device) {
+ std::stringstream device_type;
+ sycl::backend backend = device.get_backend();
+ device_type << backend << ":" << get_device_type_name(device);
+ return device_type.str();
+}
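+
+// Illustrative use: for a Level Zero GPU this typically yields a string of the
+// form "ext_oneapi_level_zero:gpu", which is the format convert_backend_index()
+// further below expects.
+//
+//   sycl::device dev{sycl::gpu_selector_v};
+//   std::string s = get_device_backend_and_type(dev);  // e.g. "ext_oneapi_level_zero:gpu"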
+
+namespace dpct
+{
+ typedef sycl::queue *queue_ptr;
+ typedef sycl::event *event_ptr;
+ typedef char *device_ptr;
+ typedef uint8_t byte_t;
+ typedef sycl::buffer<byte_t> buffer_t;
+
+ /// SYCL default exception handler
+ inline auto exception_handler = [](sycl::exception_list exceptions)
+ {
+ for (std::exception_ptr const &e : exceptions)
+ {
+ try
+ {
+ std::rethrow_exception(e);
+ }
+ catch (sycl::exception const &e)
+ {
+ std::cerr << "Caught asynchronous SYCL exception:" << std::endl
+ << e.what() << std::endl
+ << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ }
+ }
+ };
+
+ enum error_code
+ {
+ success = 0,
+ default_error = 999
+ };
+
+ enum memcpy_direction
+ {
+ host_to_host,
+ host_to_device,
+ device_to_host,
+ device_to_device,
+ automatic
+ };
+
+ enum memory_region
+ {
+ global = 0, // device global memory
+ constant, // device constant memory
+ local, // device local memory
+ shared, // memory which can be accessed by host and device
+ };
+
+ enum class library_data_t : unsigned char
+ {
+ real_float = 0,
+ complex_float,
+ real_double,
+ complex_double,
+ real_half,
+ complex_half,
+ real_bfloat16,
+ complex_bfloat16,
+ real_int4,
+ complex_int4,
+ real_uint4,
+ complex_uint4,
+ real_int8,
+ complex_int8,
+ real_uint8,
+ complex_uint8,
+ real_int16,
+ complex_int16,
+ real_uint16,
+ complex_uint16,
+ real_int32,
+ complex_int32,
+ real_uint32,
+ complex_uint32,
+ real_int64,
+ complex_int64,
+ real_uint64,
+ complex_uint64,
+ real_int8_4,
+ real_int8_32,
+ real_uint8_4,
+ library_data_t_size
+ };
+
+ template <typename T>
+ struct DataType
+ {
+ using T2 = T;
+ };
+ template <typename T>
+ struct DataType<sycl::vec<T, 2>>
+ {
+ using T2 = std::complex<T>;
+ };
+
+ static void destroy_event(event_ptr event)
+ {
+ delete event;
+ }
+
+ static inline unsigned int get_tid()
+ {
+#if defined(__linux__)
+ return syscall(SYS_gettid);
+#elif defined(_WIN64)
+ return GetCurrentThreadId();
+#else
+#error "Only support Windows and Linux."
+#endif
+ }
+
+ namespace detail
+ {
+ static void get_version(const sycl::device &dev, int &major, int &minor)
+ {
+ // Version string has the following format:
+ // a. OpenCL<space><major.minor><space><vendor-specific-information>
+ // b. <major.minor>
+ // c. <AmdGcnArchName>, e.g. gfx1030
+ std::string ver;
+ ver = dev.get_info<sycl::info::device::version>();
+ std::string::size_type i = 0;
+ while (i < ver.size()) {
+ if (isdigit(ver[i]))
+ break;
+ i++;
+ }
+ major = std::stoi(&(ver[i]));
+ while (i < ver.size()) {
+ if (ver[i] == '.')
+ break;
+ i++;
+ }
+ if (i < ver.size()) {
+ // a. and b.
+ i++;
+ minor = std::stoi(&(ver[i]));
+ } else {
+ // c.
+ minor = 0;
+ }
+ }
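+
+ // Illustrative examples of how the parsing above treats the three formats
+ // (the strings are made up, not exhaustive):
+ //
+ //   "OpenCL 3.0 (Build 0)"  ->  major = 3,    minor = 0   // format a.
+ //   "1.3"                   ->  major = 1,    minor = 3   // format b.
+ //   "gfx1030"               ->  major = 1030, minor = 0   // format c., no '.'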
+
+ template <typename tag, typename T>
+ class generic_error_type
+ {
+ public:
+ generic_error_type() = default;
+ generic_error_type(T value) : value{value} {}
+ operator T() const { return value; }
+
+ private:
+ T value;
+ };
+
+ } // namespace detail
+
+ /// Pitched 2D/3D memory data.
+ class pitched_data
+ {
+ public:
+ pitched_data() : pitched_data(nullptr, 0, 0, 0) {}
+ pitched_data(void *data, size_t pitch, size_t x, size_t y)
+ : _data(data), _pitch(pitch), _x(x), _y(y) {}
+
+ void *get_data_ptr() { return _data; }
+ void set_data_ptr(void *data) { _data = data; }
+
+ size_t get_pitch() { return _pitch; }
+ void set_pitch(size_t pitch) { _pitch = pitch; }
+
+ size_t get_x() { return _x; }
+ void set_x(size_t x) { _x = x; }
+
+ size_t get_y() { return _y; }
+ void set_y(size_t y) { _y = y; }
+
+ private:
+ void *_data;
+ size_t _pitch, _x, _y;
+ };
+
+ class device_info
+ {
+ public:
+ // get interface
+ const char *get_name() const { return _name; }
+ char *get_name() { return _name; }
+ template <typename WorkItemSizesTy = sycl::range<3>,
+ std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
+ std::is_same_v<WorkItemSizesTy, int *>,
+ int> = 0>
+ auto get_max_work_item_sizes() const
+ {
+ if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
+ return sycl::range<3>(_max_work_item_sizes_i[0],
+ _max_work_item_sizes_i[1],
+ _max_work_item_sizes_i[2]);
+ else
+ {
+ return _max_work_item_sizes_i;
+ }
+ }
+ template <typename WorkItemSizesTy = sycl::range<3>,
+ std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
+ std::is_same_v<WorkItemSizesTy, int *>,
+ int> = 0>
+ auto get_max_work_item_sizes()
+ {
+ if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
+ return sycl::range<3>(_max_work_item_sizes_i[0],
+ _max_work_item_sizes_i[1],
+ _max_work_item_sizes_i[2]);
+ else
+ {
+ return _max_work_item_sizes_i;
+ }
+ }
+ bool get_host_unified_memory() const { return _host_unified_memory; }
+ int get_major_version() const { return _major; }
+ int get_minor_version() const { return _minor; }
+ int get_integrated() const { return _integrated; }
+ int get_max_clock_frequency() const { return _frequency; }
+ int get_max_compute_units() const { return _max_compute_units; }
+ int get_max_work_group_size() const { return _max_work_group_size; }
+ int get_max_sub_group_size() const { return _max_sub_group_size; }
+ int get_max_work_items_per_compute_unit() const
+ {
+ return _max_work_items_per_compute_unit;
+ }
+ int get_max_register_size_per_work_group() const
+ {
+ return _max_register_size_per_work_group;
+ }
+ template <typename NDRangeSizeTy = size_t *,
+ std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
+ std::is_same_v<NDRangeSizeTy, int *>,
+ int> = 0>
+ auto get_max_nd_range_size() const
+ {
+ if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
+ return _max_nd_range_size;
+ else
+ return _max_nd_range_size_i;
+ }
+ template <typename NDRangeSizeTy = size_t *,
+ std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
+ std::is_same_v<NDRangeSizeTy, int *>,
+ int> = 0>
+ auto get_max_nd_range_size()
+ {
+ if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
+ return _max_nd_range_size;
+ else
+ return _max_nd_range_size_i;
+ }
+ size_t get_global_mem_size() const { return _global_mem_size; }
+ size_t get_local_mem_size() const { return _local_mem_size; }
+ size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
+ /// Returns the maximum clock rate of the device's global memory in kHz. If the
+ /// compiler does not support this API, the default value of 3200000 kHz is returned.
+ unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
+ /// Returns the maximum bus width between the device and memory in bits. If the
+ /// compiler does not support this API, the default value of 64 bits is returned.
+ unsigned int get_memory_bus_width() const { return _memory_bus_width; }
+ uint32_t get_device_id() const { return _device_id; }
+ std::array<unsigned char, 16> get_uuid() const { return _uuid; }
+ /// Returns global memory cache size in bytes.
+ unsigned int get_global_mem_cache_size() const
+ {
+ return _global_mem_cache_size;
+ }
+
+ // set interface
+ void set_name(const char *name)
+ {
+ size_t length = strlen(name);
+ if (length < 256)
+ {
+ std::memcpy(_name, name, length + 1);
+ }
+ else
+ {
+ std::memcpy(_name, name, 255);
+ _name[255] = '\0';
+ }
+ }
+ void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
+ {
+ for (int i = 0; i < 3; ++i)
+ _max_work_item_sizes_i[i] = max_work_item_sizes[i];
+ }
+ [[deprecated]] void
+ set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
+ {
+ for (int i = 0; i < 3; ++i)
+ {
+ _max_work_item_sizes_i[i] = max_work_item_sizes[i];
+ }
+ }
+ void set_host_unified_memory(bool host_unified_memory)
+ {
+ _host_unified_memory = host_unified_memory;
+ }
+ void set_major_version(int major) { _major = major; }
+ void set_minor_version(int minor) { _minor = minor; }
+ void set_integrated(int integrated) { _integrated = integrated; }
+ void set_max_clock_frequency(int frequency) { _frequency = frequency; }
+ void set_max_compute_units(int max_compute_units)
+ {
+ _max_compute_units = max_compute_units;
+ }
+ void set_global_mem_size(size_t global_mem_size)
+ {
+ _global_mem_size = global_mem_size;
+ }
+ void set_local_mem_size(size_t local_mem_size)
+ {
+ _local_mem_size = local_mem_size;
+ }
+ void set_max_mem_alloc_size(size_t max_mem_alloc_size)
+ {
+ _max_mem_alloc_size = max_mem_alloc_size;
+ }
+ void set_max_work_group_size(int max_work_group_size)
+ {
+ _max_work_group_size = max_work_group_size;
+ }
+ void set_max_sub_group_size(int max_sub_group_size)
+ {
+ _max_sub_group_size = max_sub_group_size;
+ }
+ void
+ set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
+ {
+ _max_work_items_per_compute_unit = max_work_items_per_compute_unit;
+ }
+ void set_max_nd_range_size(int max_nd_range_size[])
+ {
+ for (int i = 0; i < 3; i++)
+ {
+ _max_nd_range_size[i] = max_nd_range_size[i];
+ _max_nd_range_size_i[i] = max_nd_range_size[i];
+ }
+ }
+ void set_memory_clock_rate(unsigned int memory_clock_rate)
+ {
+ _memory_clock_rate = memory_clock_rate;
+ }
+ void set_memory_bus_width(unsigned int memory_bus_width)
+ {
+ _memory_bus_width = memory_bus_width;
+ }
+ void
+ set_max_register_size_per_work_group(int max_register_size_per_work_group)
+ {
+ _max_register_size_per_work_group = max_register_size_per_work_group;
+ }
+ void set_device_id(uint32_t device_id)
+ {
+ _device_id = device_id;
+ }
+ void set_uuid(std::array<unsigned char, 16> uuid)
+ {
+ _uuid = std::move(uuid);
+ }
+ void set_global_mem_cache_size(unsigned int global_mem_cache_size)
+ {
+ _global_mem_cache_size = global_mem_cache_size;
+ }
+
+ private:
+ char _name[256];
+ int _max_work_item_sizes_i[3];
+ bool _host_unified_memory = false;
+ int _major;
+ int _minor;
+ int _integrated = 0;
+ int _frequency;
+ // Set estimated value 3200000 kHz as default value.
+ unsigned int _memory_clock_rate = 3200000;
+ // Set estimated value 64 bits as default value.
+ unsigned int _memory_bus_width = 64;
+ unsigned int _global_mem_cache_size;
+ int _max_compute_units;
+ int _max_work_group_size;
+ int _max_sub_group_size;
+ int _max_work_items_per_compute_unit;
+ int _max_register_size_per_work_group;
+ size_t _global_mem_size;
+ size_t _local_mem_size;
+ size_t _max_mem_alloc_size;
+ size_t _max_nd_range_size[3];
+ int _max_nd_range_size_i[3];
+ uint32_t _device_id;
+ std::array<unsigned char, 16> _uuid;
+ };
+
+ static int get_major_version(const sycl::device &dev)
+ {
+ int major, minor;
+ detail::get_version(dev, major, minor);
+ return major;
+ }
+
+ static int get_minor_version(const sycl::device &dev)
+ {
+ int major, minor;
+ detail::get_version(dev, major, minor);
+ return minor;
+ }
+
+ static void get_device_info(device_info &out, const sycl::device &dev)
+ {
+ device_info prop;
+ prop.set_name(dev.get_info<sycl::info::device::name>().c_str());
+
+ int major, minor;
+ detail::get_version(dev, major, minor);
+ prop.set_major_version(major);
+ prop.set_minor_version(minor);
+
+ prop.set_max_work_item_sizes(
+#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
+ // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes
+ // is an enum class element
+ dev.get_info<sycl::info::device::max_work_item_sizes>());
+#else
+ // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by
+ // an int
+ dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
+#endif
+ prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));
+
+ prop.set_max_clock_frequency(
+ dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);
+
+ prop.set_max_compute_units(
+ dev.get_info<sycl::info::device::max_compute_units>());
+ prop.set_max_work_group_size(
+ dev.get_info<sycl::info::device::max_work_group_size>());
+ prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
+ prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
+ prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());
+
+#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
+ if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
+ {
+ unsigned int tmp =
+ dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
+ if (tmp != 0)
+ prop.set_memory_clock_rate(1000 * tmp);
+ }
+ if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
+ {
+ prop.set_memory_bus_width(
+ dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
+ }
+ if (dev.has(sycl::aspect::ext_intel_device_id))
+ {
+ prop.set_device_id(
+ dev.get_info<sycl::ext::intel::info::device::device_id>());
+ }
+ if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
+ {
+ prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
+ }
+#elif defined(_MSC_VER) && !defined(__clang__)
+#pragma message("get_device_info: querying memory_clock_rate and \
+ memory_bus_width are not supported by the compiler used. \
+ Use 3200000 kHz as memory_clock_rate default value. \
+ Use 64 bits as memory_bus_width default value.")
+#else
+#warning "get_device_info: querying memory_clock_rate and \
+ memory_bus_width are not supported by the compiler used. \
+ Use 3200000 kHz as memory_clock_rate default value. \
+ Use 64 bits as memory_bus_width default value."
+#endif
+
+ size_t max_sub_group_size = 1;
+ std::vector<size_t> sub_group_sizes =
+ dev.get_info<sycl::info::device::sub_group_sizes>();
+
+ for (const auto &sub_group_size : sub_group_sizes)
+ {
+ if (max_sub_group_size < sub_group_size)
+ max_sub_group_size = sub_group_size;
+ }
+
+ prop.set_max_sub_group_size(max_sub_group_size);
+
+ prop.set_max_work_items_per_compute_unit(
+ dev.get_info<sycl::info::device::max_work_group_size>());
+ int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
+ prop.set_max_nd_range_size(max_nd_range_size);
+
+ // Estimates max register size per work group, feel free to update the value
+ // according to device properties.
+ prop.set_max_register_size_per_work_group(65536);
+
+ prop.set_global_mem_cache_size(
+ dev.get_info<sycl::info::device::global_mem_cache_size>());
+ out = prop;
+ }
+
+ /// dpct device extension
+ class device_ext : public sycl::device {
+ typedef std::mutex mutex_type;
+
+ public:
+ device_ext() : sycl::device() {}
+ ~device_ext() {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ clear_queues();
+ }
+ device_ext(const sycl::device &base) : sycl::device(base) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ init_queues();
+ }
+
+ int is_native_atomic_supported() { return 0; }
+ int get_major_version() const { return dpct::get_major_version(*this); }
+
+ int get_minor_version() const { return dpct::get_minor_version(*this); }
+
+ int get_max_compute_units() const {
+ return get_device_info().get_max_compute_units();
+ }
+
+ /// Return the maximum clock frequency of this device in KHz.
+ int get_max_clock_frequency() const {
+ return get_device_info().get_max_clock_frequency();
+ }
+
+ int get_integrated() const { return get_device_info().get_integrated(); }
+
+ int get_max_sub_group_size() const {
+ return get_device_info().get_max_sub_group_size();
+ }
+
+ int get_max_register_size_per_work_group() const {
+ return get_device_info().get_max_register_size_per_work_group();
+ }
+
+ int get_max_work_group_size() const {
+ return get_device_info().get_max_work_group_size();
+ }
+
+ int get_mem_base_addr_align() const {
+ return get_info<sycl::info::device::mem_base_addr_align>();
+ }
+
+ size_t get_global_mem_size() const {
+ return get_device_info().get_global_mem_size();
+ }
+
+ size_t get_max_mem_alloc_size() const {
+ return get_device_info().get_max_mem_alloc_size();
+ }
+
+ /// Get the number of bytes of free and total memory on the SYCL device.
+ /// \param [out] free_memory The number of bytes of free memory on the
+ /// SYCL device. \param [out] total_memory The number of bytes of total
+ /// memory on the SYCL device.
+ void get_memory_info(size_t &free_memory, size_t &total_memory) {
+ total_memory = get_device_info().get_global_mem_size();
+ const char *warning_info =
+ "get_memory_info: [warning] ext_intel_free_memory is not "
+ "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
+ "use total memory as free memory";
+#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
+ if (!has(sycl::aspect::ext_intel_free_memory)) {
+ std::cerr << warning_info << std::endl;
+ free_memory = total_memory;
+ } else {
+ free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
+ }
+#else
+ std::cerr << warning_info << std::endl;
+ free_memory = total_memory;
+#if defined(_MSC_VER) && !defined(__clang__)
+#pragma message("Querying the number of bytes of free memory is not supported")
+#else
+#warning "Querying the number of bytes of free memory is not supported"
+#endif
+#endif
+ }
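+
+ // Minimal usage sketch (names resolved through dev_mgr further below):
+ //
+ //   size_t free_bytes = 0, total_bytes = 0;
+ //   dpct::dev_mgr::instance().current_device().get_memory_info(free_bytes, total_bytes);
+ //   // without ZES_ENABLE_SYSMAN=1 (or on older compilers) free_bytes falls back to total_bytes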
+
+ void get_device_info(device_info &out) const {
+ dpct::get_device_info(out, *this);
+ }
+
+ device_info get_device_info() const {
+ device_info prop;
+ dpct::get_device_info(prop, *this);
+ return prop;
+ }
+
+ void reset() {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ clear_queues();
+ init_queues();
+ }
+
+ sycl::queue &in_order_queue() { return _q_in_order; }
+
+ sycl::queue &out_of_order_queue() { return _q_out_of_order; }
+
+ sycl::queue &default_queue() { return in_order_queue(); }
+
+ void queues_wait_and_throw() {
+ std::unique_lock<mutex_type> lock(m_mutex);
+ lock.unlock();
+ for (auto &q : _queues) {
+ q.wait_and_throw();
+ }
+ // Guard the destruct of current_queues to make sure the ref count is
+ // safe.
+ lock.lock();
+ }
+
+ sycl::queue create_queue(bool enable_exception_handler = false) {
+ return create_in_order_queue(enable_exception_handler);
+ }
+
+ sycl::queue create_queue(sycl::device device,
+ bool enable_exception_handler = false) {
+ return create_in_order_queue(device, enable_exception_handler);
+ }
+
+ sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ return create_queue_impl(enable_exception_handler,
+ sycl::property::queue::in_order());
+ }
+
+ sycl::queue create_in_order_queue(sycl::device device,
+ bool enable_exception_handler = false) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ return create_queue_impl(device, enable_exception_handler,
+ sycl::property::queue::in_order());
+ }
+
+ sycl::queue create_out_of_order_queue(
+ bool enable_exception_handler = false) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ return create_queue_impl(enable_exception_handler);
+ }
+
+ void destroy_queue(sycl::queue queue) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
+ [=](const sycl::queue &q) -> bool
+ {
+ return q == queue;
+ }),
+ _queues.end());
+ }
+ void set_saved_queue(sycl::queue q) {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ _saved_queue = q;
+ }
+ sycl::queue get_saved_queue() const {
+ std::lock_guard<mutex_type> lock(m_mutex);
+ return _saved_queue;
+ }
+
+ private:
+ void clear_queues() { _queues.clear(); }
+
+ void init_queues() {
+ _q_in_order =
+ create_queue_impl(true, sycl::property::queue::in_order());
+ _q_out_of_order = create_queue_impl(true);
+ _saved_queue = default_queue();
+ }
+
+ /// Caller should acquire resource \p m_mutex before calling this
+ /// function.
+ template <class... Properties>
+ sycl::queue create_queue_impl(bool enable_exception_handler,
+ Properties... properties) {
+ sycl::async_handler eh = {};
+ if (enable_exception_handler) {
+ eh = exception_handler;
+ }
+ _queues.push_back(sycl::queue(
+ *this, eh,
+ sycl::property_list(
+#ifdef DPCT_PROFILING_ENABLED
+ sycl::property::queue::enable_profiling(),
+#endif
+ properties...)));
+
+ return _queues.back();
+ }
+
+ template <class... Properties>
+ sycl::queue create_queue_impl(sycl::device device,
+ bool enable_exception_handler,
+ Properties... properties) {
+ sycl::async_handler eh = {};
+ if (enable_exception_handler) {
+ eh = exception_handler;
+ }
+ _queues.push_back(sycl::queue(
+ device, eh,
+ sycl::property_list(
+#ifdef DPCT_PROFILING_ENABLED
+ sycl::property::queue::enable_profiling(),
+#endif
+ properties...)));
+
+ return _queues.back();
+ }
+
+ void get_version(int &major, int &minor) const {
+ detail::get_version(*this, major, minor);
+ }
+ sycl::queue _q_in_order, _q_out_of_order;
+ sycl::queue _saved_queue;
+ std::vector<sycl::queue> _queues;
+ mutable mutex_type m_mutex;
+ };
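+
+ // Illustrative sketch of direct use; in this codebase a device_ext is normally
+ // obtained through dev_mgr below rather than constructed by hand.
+ //
+ //   dpct::device_ext dev{sycl::device(sycl::gpu_selector_v)};
+ //   sycl::queue &q = dev.in_order_queue();   // in-order queue created by init_queues()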
+
+
+ /// device manager
+ class dev_mgr
+ {
+ public:
+ device_ext &current_device()
+ {
+ unsigned int dev_id = current_device_id();
+ check_id(dev_id);
+ return *_devs[dev_id];
+ }
+ device_ext &cpu_device() const
+ {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ if (_cpu_device == -1)
+ {
+ throw std::runtime_error("no valid cpu device");
+ }
+ else
+ {
+ return *_devs[_cpu_device];
+ }
+ }
+ device_ext &get_device(unsigned int id) const
+ {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ check_id(id);
+ return *_devs[id];
+ }
+ unsigned int current_device_id() const
+ {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ auto it = _thread2dev_map.find(get_tid());
+ if (it != _thread2dev_map.end())
+ return it->second;
+ return DEFAULT_DEVICE_ID;
+ }
+
+ /// Select device with a device ID.
+ /// \param [in] id The id of the device which can
+ /// be obtained through get_device_id(const sycl::device).
+ void select_device(unsigned int id)
+ {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ check_id(id);
+ _thread2dev_map[get_tid()] = id;
+ }
+ unsigned int device_count() { return _devs.size(); }
+
+ unsigned int get_device_id(const sycl::device &dev)
+ {
+ unsigned int id = 0;
+ for (auto &dev_item : _devs)
+ {
+ if (*dev_item == dev)
+ {
+ return id;
+ }
+ id++;
+ }
+ return -1;
+ }
+
+ inline std::string get_preferred_gpu_platform_name() {
+ std::string result;
+
+ std::string filter = "level-zero";
+ char* env = getenv("ONEAPI_DEVICE_SELECTOR");
+ if (env) {
+ if (std::strstr(env, "level_zero")) {
+ filter = "level-zero";
+ }
+ else if (std::strstr(env, "opencl")) {
+ filter = "opencl";
+ }
+ else if (std::strstr(env, "cuda")) {
+ filter = "cuda";
+ }
+ else if (std::strstr(env, "hip")) {
+ filter = "hip";
+ }
+ else {
+ throw std::runtime_error("invalid device filter: " + std::string(env));
+ }
+ }
+
+ auto platform_list = sycl::platform::get_platforms();
+
+ for (const auto& platform : platform_list) {
+ auto devices = platform.get_devices();
+ auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
+ return d.is_gpu();
+ });
+
+ if (gpu_dev == devices.end()) {
+ // cout << "platform [" << platform_name
+ // << "] does not contain GPU devices, skipping\n";
+ continue;
+ }
+
+ auto platform_name = platform.get_info<sycl::info::platform::name>();
+ std::string platform_name_low_case;
+ platform_name_low_case.resize(platform_name.size());
+
+ std::transform(
+ platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
+
+ if (platform_name_low_case.find(filter) == std::string::npos) {
+ // cout << "platform [" << platform_name
+ // << "] does not match with requested "
+ // << filter << ", skipping\n";
+ continue;
+ }
+
+ result = platform_name;
+ }
+
+ if (result.empty())
+ throw std::runtime_error("can not find preferred GPU platform");
+
+ return result;
+ }
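+
+ // Illustrative mapping (one common form of the selector; the exact strings
+ // accepted by the SYCL runtime may vary):
+ //
+ //   ONEAPI_DEVICE_SELECTOR unset           -> filter "level-zero" (default)
+ //   ONEAPI_DEVICE_SELECTOR=level_zero:gpu  -> filter "level-zero"
+ //   ONEAPI_DEVICE_SELECTOR=opencl:gpu      -> filter "opencl"
+ //   ONEAPI_DEVICE_SELECTOR=cuda:gpu        -> filter "cuda"
+ //
+ // The chosen filter is then matched case-insensitively against the platform name.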
+
+ template <class DeviceSelector>
+ std::enable_if_t<
+ std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
+ select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
+ {
+ sycl::device selected_device = sycl::device(selector);
+ unsigned int selected_device_id = get_device_id(selected_device);
+ select_device(selected_device_id);
+ }
+
+ /// Returns the instance of device manager singleton.
+ static dev_mgr &instance()
+ {
+ static dev_mgr d_m;
+ return d_m;
+ }
+ dev_mgr(const dev_mgr &) = delete;
+ dev_mgr &operator=(const dev_mgr &) = delete;
+ dev_mgr(dev_mgr &&) = delete;
+ dev_mgr &operator=(dev_mgr &&) = delete;
+
+ private:
+ mutable std::recursive_mutex m_mutex;
+ static bool compare_dev(sycl::device &device1, sycl::device &device2)
+ {
+ sycl::backend backend1 = device1.get_backend();
+ sycl::backend backend2 = device2.get_backend();
+ // levelzero backends always come first
+ if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
+ if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
+ dpct::device_info prop1;
+ dpct::get_device_info(prop1, device1);
+ dpct::device_info prop2;
+ dpct::get_device_info(prop2, device2);
+ return prop1.get_max_compute_units() > prop2.get_max_compute_units();
+ }
+ static int convert_backend_index(std::string & backend) {
+ if (backend == "ext_oneapi_level_zero:gpu") return 0;
+ if (backend == "opencl:gpu") return 1;
+ if (backend == "ext_oneapi_cuda:gpu") return 2;
+ if (backend == "ext_oneapi_hip:gpu") return 3;
+ if (backend == "opencl:cpu") return 4;
+ if (backend == "opencl:acc") return 5;
+ printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
+ GGML_ASSERT(false);
+ }
+ static bool compare_backend(std::string &backend1, std::string &backend2) {
+ return convert_backend_index(backend1) < convert_backend_index(backend2);
+ }
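+
+ // Net effect of the two comparators above (illustrative): backends are grouped
+ // in the fixed order
+ //
+ //   ext_oneapi_level_zero:gpu < opencl:gpu < ext_oneapi_cuda:gpu
+ //                             < ext_oneapi_hip:gpu < opencl:cpu < opencl:acc
+ //
+ // and within each group compare_dev() sorts devices by max compute units
+ // (preferring Level Zero when backends differ).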
+ dev_mgr()
+ {
+ sycl::device default_device =
+ sycl::device(sycl::default_selector_v);
+ _devs.push_back(std::make_shared<device_ext>(default_device));
+
+ std::vector<sycl::device> sycl_all_devs;
+ // Collect other devices except for the default device.
+ if (default_device.is_cpu())
+ _cpu_device = 0;
+
+ auto Platforms = sycl::platform::get_platforms();
+ // Keep track of the number of devices per backend
+ std::map<sycl::backend, size_t> DeviceNums;
+ std::map<std::string, std::vector<sycl::device>> backend_devices;
+ auto preferred_platform_name = get_preferred_gpu_platform_name();
+
+ while (!Platforms.empty()) {
+ auto Platform = Platforms.back();
+ Platforms.pop_back();
+ auto platform_name = Platform.get_info<sycl::info::platform::name>();
+ if (platform_name.compare(preferred_platform_name) != 0) {
+ continue;
+ }
+ auto devices = Platform.get_devices();
+ std::string backend_type = get_device_backend_and_type(devices[0]);
+ for (const auto &device : devices) {
+ backend_devices[backend_type].push_back(device);
+ }
+ }
+
+ std::vector<std::string> keys;
+ for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
+ keys.push_back(it->first);
+ }
+ std::sort(keys.begin(), keys.end(), compare_backend);
+
+ for (auto &key : keys) {
+ std::vector<sycl::device> devs = backend_devices[key];
+ std::sort(devs.begin(), devs.end(), compare_dev);
+ for (const auto &dev : devs) {
+ sycl_all_devs.push_back(dev);
+ }
+ }
+
+ for (auto &dev : sycl_all_devs)
+ {
+ if (dev == default_device)
+ {
+ continue;
+ }
+ _devs.push_back(std::make_shared<device_ext>(dev));
+ if (_cpu_device == -1 && dev.is_cpu())
+ {
+ _cpu_device = _devs.size() - 1;
+ }
+ }
+ }
+ void check_id(unsigned int id) const
+ {
+ if (id >= _devs.size())
+ {
+ throw std::runtime_error("invalid device id");
+ }
+ }
+ std::vector<std::shared_ptr<device_ext>> _devs;
+ /// DEFAULT_DEVICE_ID is used if current_device_id() cannot find the current
+ /// thread id in _thread2dev_map, i.e. the default device should be used
+ /// for the current thread.
+ const unsigned int DEFAULT_DEVICE_ID = 0;
+ /// thread-id to device-id map.
+ std::map<unsigned int, unsigned int> _thread2dev_map;
+ int _cpu_device = -1;
+ };
+
+ static inline sycl::queue &get_default_queue()
+ {
+ return dev_mgr::instance().current_device().default_queue();
+ }
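+ // Hedged usage sketch (not part of the original header): grab the default
+ // device's default queue and run a plain SYCL operation on it; `dev_ptr` and
+ // `bytes` are assumed to be a USM device allocation and its size.
+ //   sycl::queue &q = dpct::get_default_queue();
+ //   q.memset(dev_ptr, 0, bytes).wait();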
+
+ namespace detail
+ {
+ enum class pointer_access_attribute
+ {
+ host_only = 0,
+ device_only,
+ host_device,
+ end
+ };
+
+ static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
+ const void *ptr)
+ {
+ switch (sycl::get_pointer_type(ptr, q.get_context()))
+ {
+ case sycl::usm::alloc::unknown:
+ return pointer_access_attribute::host_only;
+ case sycl::usm::alloc::device:
+ return pointer_access_attribute::device_only;
+ case sycl::usm::alloc::shared:
+ case sycl::usm::alloc::host:
+ return pointer_access_attribute::host_device;
+ }
+ }
+
+ template <typename ArgT>
+ inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
+ {
+ static_assert((unsigned char)library_data_t::library_data_t_size <=
+ std::numeric_limits<unsigned char>::max() &&
+ "library_data_t size exceeds limit.");
+ static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
+ return (std::uint64_t)Val;
+ }
+
+ template <typename FirstT, typename... RestT>
+ inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
+ RestT... RestVal)
+ {
+ static_assert((std::uint8_t)library_data_t::library_data_t_size <=
+ std::numeric_limits<unsigned char>::max() &&
+ "library_data_t size exceeds limit.");
+ static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
+ static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
+ return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
+ }
+
+ class mem_mgr
+ {
+ mem_mgr()
+ {
+ // Reserve the address space up front; no real memory allocation happens here.
+#if defined(__linux__)
+ mapped_address_space =
+ (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+#elif defined(_WIN64)
+ mapped_address_space = (byte_t *)VirtualAlloc(
+ NULL, // NULL specified as the base address parameter
+ mapped_region_size, // Size of allocation
+ MEM_RESERVE, // Allocate reserved pages
+ PAGE_NOACCESS); // Protection = no access
+#else
+#error "Only support Windows and Linux."
+#endif
+ next_free = mapped_address_space;
+ }
+
+ public:
+ using buffer_id_t = int;
+
+ struct allocation
+ {
+ buffer_t buffer;
+ byte_t *alloc_ptr;
+ size_t size;
+ };
+
+ ~mem_mgr()
+ {
+#if defined(__linux__)
+ munmap(mapped_address_space, mapped_region_size);
+#elif defined(_WIN64)
+ VirtualFree(mapped_address_space, 0, MEM_RELEASE);
+#else
+#error "Only support Windows and Linux."
+#endif
+ }
+
+ mem_mgr(const mem_mgr &) = delete;
+ mem_mgr &operator=(const mem_mgr &) = delete;
+ mem_mgr(mem_mgr &&) = delete;
+ mem_mgr &operator=(mem_mgr &&) = delete;
+
+ /// Allocate
+ void *mem_alloc(size_t size)
+ {
+ if (!size)
+ return nullptr;
+ std::lock_guard<std::mutex> lock(m_mutex);
+ if (next_free + size > mapped_address_space + mapped_region_size)
+ {
+ throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
+ }
+ // Allocation
+ sycl::range<1> r(size);
+ buffer_t buf(r);
+ allocation A{buf, next_free, size};
+ // Map allocation to device pointer
+ void *result = next_free;
+ m_map.emplace(next_free + size, A);
+ // Update pointer to the next free space.
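+ // Rounds the increment up to a multiple of `alignment`; e.g. with
+ // size = 1000, extra_padding = 0 and alignment = 256:
+ // (1000 + 0 + 255) & ~255 == 1024, so the next allocation starts 1024 bytes later.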
+ next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);
+
+ return result;
+ }
+
+ /// Deallocate
+ void mem_free(const void *ptr)
+ {
+ if (!ptr)
+ return;
+ std::lock_guard<std::mutex> lock(m_mutex);
+ auto it = get_map_iterator(ptr);
+ m_map.erase(it);
+ }
+
+ /// map: device pointer -> allocation(buffer, alloc_ptr, size)
+ allocation translate_ptr(const void *ptr)
+ {
+ std::lock_guard<std::mutex> lock(m_mutex);
+ auto it = get_map_iterator(ptr);
+ return it->second;
+ }
+
+ /// Check whether the pointer is a device pointer managed by this virtual memory pool.
+ bool is_device_ptr(const void *ptr) const
+ {
+ std::lock_guard<std::mutex> lock(m_mutex);
+ return (mapped_address_space <= ptr) &&
+ (ptr < mapped_address_space + mapped_region_size);
+ }
+
+ /// Returns the instance of memory manager singleton.
+ static mem_mgr &instance()
+ {
+ static mem_mgr m;
+ return m;
+ }
+
+ private:
+ std::map<byte_t *, allocation> m_map;
+ mutable std::mutex m_mutex;
+ byte_t *mapped_address_space;
+ byte_t *next_free;
+ const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
+ const size_t alignment = 256;
+ /// This padding may be defined to some positive value to debug
+ /// out of bound accesses.
+ const size_t extra_padding = 0;
+
+ std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
+ {
+ auto it = m_map.upper_bound((byte_t *)ptr);
+ if (it == m_map.end())
+ {
+ // Not a virtual pointer.
+ throw std::runtime_error("can not get buffer from non-virtual pointer");
+ }
+ const allocation &alloc = it->second;
+ if (ptr < alloc.alloc_ptr)
+ {
+ // Out of bound.
+ // This may happen if there is a gap between allocations due to alignment
+ // or extra padding and the pointer points into this gap.
+ throw std::runtime_error("invalid virtual pointer");
+ }
+ return it;
+ }
+ };
+
+ template <class T, memory_region Memory, size_t Dimension>
+ class accessor;
+ template <memory_region Memory, class T = byte_t>
+ class memory_traits
+ {
+ public:
+ static constexpr sycl::access::target target =
+ sycl::access::target::device;
+ static constexpr sycl::access_mode mode =
+ (Memory == constant) ? sycl::access_mode::read
+ : sycl::access_mode::read_write;
+ static constexpr size_t type_size = sizeof(T);
+ using element_t =
+ typename std::conditional<Memory == constant, const T, T>::type;
+ using value_t = typename std::remove_cv<T>::type;
+ template <size_t Dimension = 1>
+ using accessor_t = typename std::conditional<
+ Memory == local, sycl::local_accessor<value_t, Dimension>,
+ sycl::accessor<T, Dimension, mode, target>>::type;
+ using pointer_t = T *;
+ };
+
+ static inline void *dpct_malloc(size_t size, sycl::queue &q)
+ {
+ return sycl::malloc_device(size, q.get_device(), q.get_context());
+ }
+
+#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
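+ // PITCH_DEFAULT_ALIGN rounds x up to the next multiple of 32 elements,
+ // e.g. PITCH_DEFAULT_ALIGN(100) == 128, so each row starts on a 32-element boundary.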
+ static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
+ sycl::queue &q)
+ {
+ pitch = PITCH_DEFAULT_ALIGN(x);
+ return dpct_malloc(pitch * y * z, q);
+ }
+
+ /**
+ * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
+ * @tparam valueT The type of the element to be set.
+ * @param [in] q The queue in which the operation is done.
+ * @param [in] dev_ptr Pointer to the virtual device memory address.
+ * @param [in] value The value to be set.
+ * @param [in] size Number of elements to be set to the value.
+ * @return An event representing the memset operation.
+ */
+ template <typename valueT>
+ static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
+ valueT value, size_t size)
+ {
+ return q.fill(dev_ptr, value, size);
+ }
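+ // Hedged usage sketch (assumed USM device pointer `dev_ptr` holding `n` bytes):
+ //   dpct_memset(q, dev_ptr, (unsigned char)0, n).wait();
+ // q.fill() writes `size` elements of `value`, so the element type chosen here
+ // determines the granularity of the fill.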
+
+ /**
+ * @brief Sets \p value to the 3D memory region pointed by \p data in \p q.
+ * @tparam valueT The type of the element to be set.
+ * @param [in] q The queue in which the operation is done.
+ * @param [in] data Pointer to the pitched device memory region.
+ * @param [in] value The value to be set.
+ * @param [in] size 3D memory region by number of elements.
+ * @return An event list representing the memset operations.
+ */
+ template <typename valueT>
+ static inline std::vector<sycl::event>
+ dpct_memset(sycl::queue &q, pitched_data data, valueT value,
+ sycl::range<3> size)
+ {
+ std::vector<sycl::event> event_list;
+ size_t slice = data.get_pitch() * data.get_y();
+ unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
+ for (size_t z = 0; z < size.get(2); ++z)
+ {
+ unsigned char *data_ptr = data_surface;
+ for (size_t y = 0; y < size.get(1); ++y)
+ {
+ event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
+ data_ptr += data.get_pitch();
+ }
+ data_surface += slice;
+ }
+ return event_list;
+ }
+
+ /**
+ * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q.
+ * @tparam valueT The type of the element to be set.
+ * @param [in] q The queue in which the operation is done.
+ * @param [in] ptr Pointer to the virtual device memory.
+ * @param [in] pitch The pitch size by number of elements, including padding.
+ * @param [in] val The value to be set.
+ * @param [in] x The width of memory region by number of elements.
+ * @param [in] y The height of memory region by number of elements.
+ * @return An event list representing the memset operations.
+ */
+ template <typename valueT>
+ static inline std::vector<sycl::event>
+ dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
+ size_t y)
+ {
+ return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
+ sycl::range<3>(x, y, 1));
+ }
+
+ static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
+ const void *from_ptr,
+ memcpy_direction dir)
+ {
+ switch (dir)
+ {
+ case memcpy_direction::host_to_host:
+ case memcpy_direction::host_to_device:
+ case memcpy_direction::device_to_host:
+ case memcpy_direction::device_to_device:
+ return dir;
+ case memcpy_direction::automatic:
+ {
+ // table[to_attribute][from_attribute]
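+ // rows    : attribute of to_ptr   (host_only, device_only, host_device)
+ // columns : attribute of from_ptr (host_only, device_only, host_device)
+ // in the declaration order of pointer_access_attribute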
+ static const memcpy_direction
+ direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
+ [static_cast<unsigned>(pointer_access_attribute::end)] =
+ {{memcpy_direction::host_to_host,
+ memcpy_direction::device_to_host,
+ memcpy_direction::host_to_host},
+ {memcpy_direction::host_to_device,
+ memcpy_direction::device_to_device,
+ memcpy_direction::device_to_device},
+ {memcpy_direction::host_to_host,
+ memcpy_direction::device_to_device,
+ memcpy_direction::device_to_device}};
+ return direction_table[static_cast<unsigned>(get_pointer_attribute(
+ q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
+ }
+ default:
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
+ }
+ }
+
+ static sycl::event
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
+ memcpy_direction direction,
+ const std::vector<sycl::event> &dep_events = {})
+ {
+ if (!size)
+ return sycl::event{};
+ return q.memcpy(to_ptr, from_ptr, size, dep_events);
+ GGML_UNUSED(direction);
+ }
+
+ // Get the actual extent of the strided copy; buffers sized with it are never exceeded.
+ static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
+ size_t pitch)
+ {
+ return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
+ }
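+ // e.g. for size = (x, y, z): the last element touched lives at offset
+ // slice*(z-1) + pitch*(y-1) + (x-1), so slice*(z-1) + pitch*(y-1) + x bytes
+ // are enough to hold the whole strided region.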
+
+ static inline size_t get_offset(sycl::id<3> id, size_t slice,
+ size_t pitch)
+ {
+ return slice * id.get(2) + pitch * id.get(1) + id.get(0);
+ }
+
+ /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
+ /// and \p from_range to another specified by \p to_ptr and \p to_range.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
+ sycl::range<3> to_range, sycl::range<3> from_range,
+ sycl::id<3> to_id, sycl::id<3> from_id,
+ sycl::range<3> size, memcpy_direction direction,
+ const std::vector<sycl::event> &dep_events = {})
+ {
+ // RAII for host pointer
+ class host_buffer
+ {
+ void *_buf;
+ size_t _size;
+ sycl::queue &_q;
+ const std::vector<sycl::event> &_deps; // free operation depends
+
+ public:
+ host_buffer(size_t size, sycl::queue &q,
+ const std::vector<sycl::event> &deps)
+ : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
+ void *get_ptr() const { return _buf; }
+ size_t get_size() const { return _size; }
+ ~host_buffer()
+ {
+ if (_buf)
+ {
+ _q.submit([&](sycl::handler &cgh)
+ {
+ cgh.depends_on(_deps);
+ cgh.host_task([buf = _buf] { std::free(buf); }); });
+ }
+ }
+ };
+ std::vector<sycl::event> event_list;
+
+ size_t to_slice = to_range.get(1) * to_range.get(0),
+ from_slice = from_range.get(1) * from_range.get(0);
+ unsigned char *to_surface =
+ (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
+ const unsigned char *from_surface =
+ (const unsigned char *)from_ptr +
+ get_offset(from_id, from_slice, from_range.get(0));
+
+ if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
+ {
+ return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
+ direction, dep_events)};
+ }
+ direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
+ size_t size_slice = size.get(1) * size.get(0);
+ switch (direction)
+ {
+ case host_to_host:
+ for (size_t z = 0; z < size.get(2); ++z)
+ {
+ unsigned char *to_ptr = to_surface;
+ const unsigned char *from_ptr = from_surface;
+ if (to_range.get(0) == from_range.get(0) &&
+ to_range.get(0) == size.get(0))
+ {
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
+ direction, dep_events));
+ }
+ else
+ {
+ for (size_t y = 0; y < size.get(1); ++y)
+ {
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
+ direction, dep_events));
+ to_ptr += to_range.get(0);
+ from_ptr += from_range.get(0);
+ }
+ }
+ to_surface += to_slice;
+ from_surface += from_slice;
+ }
+ break;
+ case host_to_device:
+ {
+ host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
+ event_list);
+ std::vector<sycl::event> host_events;
+ if (to_slice == size_slice)
+ {
+ // Copy host data to a temp host buffer with the shape of target.
+ host_events =
+ dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
+ host_to_host, dep_events);
+ }
+ else
+ {
+ // Copy host data to a temp host buffer with the shape of target.
+ host_events = dpct_memcpy(
+ q, buf.get_ptr(), from_surface, to_range, from_range,
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
+ // The destination may contain padding bytes whose contents might still matter,
+ // so prefill the temp buffer with the existing device data before overlaying.
+ std::vector<sycl::event>{
+ dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
+ device_to_host, dep_events)});
+ }
+ // Copy from temp host buffer to device with only one submit.
+ event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
+ buf.get_size(), host_to_device,
+ host_events));
+ break;
+ }
+ case device_to_host:
+ {
+ host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
+ event_list);
+ // Copy from host temp buffer to host target with reshaping.
+ event_list = dpct_memcpy(
+ q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
+ sycl::id<3>(0, 0, 0), size, host_to_host,
+ // Copy from device to temp host buffer with only one submit.
+ std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
+ buf.get_size(),
+ device_to_host, dep_events)});
+ break;
+ }
+ case device_to_device:
+ event_list.push_back(q.submit([&](sycl::handler &cgh){
+ cgh.depends_on(dep_events);
+ cgh.parallel_for<class dpct_memcpy_3d_detail>(
+ size,
+ [=](sycl::id<3> id) {
+ to_surface[get_offset(id, to_slice, to_range.get(0))] =
+ from_surface[get_offset(id, from_slice, from_range.get(0))];
+ }); }));
+ break;
+ default:
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
+ }
+ return event_list;
+ }
+
+ /// memcpy 2D/3D matrix specified by pitched_data.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
+ pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
+ memcpy_direction direction = automatic)
+ {
+ return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
+ sycl::range<3>(to.get_pitch(), to.get_y(), 1),
+ sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
+ size, direction);
+ }
+
+ /// memcpy 2D matrix with pitch.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
+ size_t to_pitch, size_t from_pitch, size_t x, size_t y,
+ memcpy_direction direction = automatic)
+ {
+ return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
+ sycl::range<3>(from_pitch, y, 1),
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
+ sycl::range<3>(x, y, 1), direction);
+ }
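+ // Hedged usage sketch (assumed pointers with pitches given in bytes): copy an
+ // x-by-y sub-matrix between two pitched device allocations and wait for it.
+ //   for (auto &e : dpct_memcpy(q, dst, src, dst_pitch, src_pitch,
+ //                              x, y, device_to_device))
+ //     e.wait();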
+
+ namespace deprecated
+ {
+
+ template <typename T, sycl::usm::alloc AllocKind>
+ class usm_allocator
+ {
+ private:
+ using Alloc = sycl::usm_allocator<T, AllocKind>;
+ Alloc _impl;
+
+ public:
+ using value_type = typename std::allocator_traits<Alloc>::value_type;
+ using pointer = typename std::allocator_traits<Alloc>::pointer;
+ using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
+ using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
+ using const_void_pointer =
+ typename std::allocator_traits<Alloc>::const_void_pointer;
+ using reference = typename std::allocator_traits<Alloc>::value_type &;
+ using const_reference =
+ const typename std::allocator_traits<Alloc>::value_type &;
+ using difference_type =
+ typename std::allocator_traits<Alloc>::difference_type;
+ using size_type = typename std::allocator_traits<Alloc>::size_type;
+ using propagate_on_container_copy_assignment = typename std::allocator_traits<
+ Alloc>::propagate_on_container_copy_assignment;
+ using propagate_on_container_move_assignment = typename std::allocator_traits<
+ Alloc>::propagate_on_container_move_assignment;
+ using propagate_on_container_swap =
+ typename std::allocator_traits<Alloc>::propagate_on_container_swap;
+ using is_always_equal =
+ typename std::allocator_traits<Alloc>::is_always_equal;
+
+ template <typename U>
+ struct rebind
+ {
+ typedef usm_allocator<U, AllocKind> other;
+ };
+
+ usm_allocator() : _impl(dpct::get_default_queue()) {}
+ ~usm_allocator() {}
+ usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
+ usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
+ pointer address(reference r) { return &r; }
+ const_pointer address(const_reference r) { return &r; }
+ pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
+ {
+ return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
+ }
+ void deallocate(pointer p, size_type cnt)
+ {
+ std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
+ }
+ size_type max_size() const
+ {
+ return std::allocator_traits<Alloc>::max_size(_impl);
+ }
+ bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
+ bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
+ };
+
+ } // namespace deprecated
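+ // Hedged usage sketch: the (deprecated) allocator plugs into standard
+ // containers; shared USM keeps the storage visible on host and device.
+ //   std::vector<float, deprecated::usm_allocator<float, sycl::usm::alloc::shared>> v;
+ //   v.resize(1024);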
+
+ inline void dpct_free(void *ptr,
+ const sycl::queue &q)
+ {
+ if (ptr)
+ {
+ sycl::free(ptr, q.get_context());
+ }
+ }
+
+ template <typename T>
+ inline auto get_memory(const void *x)
+ {
+ T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
+ return new_x;
+ }
+
+ template <typename T>
+ inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
+ {
+ using Ty = typename DataType<T>::T2;
+ Ty s_h;
+ if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
+ detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
+ .wait();
+ else
+ s_h = *reinterpret_cast<const Ty *>(s);
+ return s_h;
+ }
+
+ } // namespace detail
+
+ template <typename T>
+ inline auto get_value(const T *s, sycl::queue &q)
+ {
+ return detail::get_value(s, q);
+ }
+
+ namespace detail
+ {
+ template <class Ta, class Tb, class Tc, class Ts>
+ inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n, int k,
+ const void *alpha, const void *a, int lda, const void *b,
+ int ldb, const void *beta, void *c, int ldc)
+ {
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+ auto data_a = get_memory<const Ta>(a);
+ auto data_b = get_memory<const Tb>(b);
+ auto data_c = get_memory<Tc>(c);
+ oneapi::mkl::blas::column_major::gemm(
+ q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+ data_b, ldb, beta_value, data_c, ldc);
+ }
+
+ template <typename VecT, class BinaryOperation, class = void>
+ class vectorized_binary
+ {
+ public:
+ inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
+ {
+ VecT v4;
+ for (size_t i = 0; i < v4.size(); ++i)
+ {
+ v4[i] = binary_op(a[i], b[i]);
+ }
+ return v4;
+ }
+ };
+
+ template <typename VecT, class BinaryOperation>
+ class vectorized_binary<
+ VecT, BinaryOperation,
+ std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
+ {
+ public:
+ inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
+ {
+ return binary_op(a, b).template as<VecT>();
+ }
+ };
+
+ template <class Ta, class Tb, class Tc, class Ts>
+ inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n, int k,
+ const void *alpha, const void **a, int lda,
+ const void **b, int ldb, const void *beta, void **c,
+ int ldc, int batch_size)
+ {
+ struct matrix_info_t
+ {
+ oneapi::mkl::transpose transpose_info[2];
+ Ts value_info[2];
+ std::int64_t size_info[3];
+ std::int64_t ld_info[3];
+ std::int64_t groupsize_info;
+ };
+
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+
+ matrix_info_t *matrix_info =
+ (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
+ matrix_info->transpose_info[0] = a_trans;
+ matrix_info->transpose_info[1] = b_trans;
+ matrix_info->value_info[0] = alpha_value;
+ matrix_info->value_info[1] = beta_value;
+ matrix_info->size_info[0] = m;
+ matrix_info->size_info[1] = n;
+ matrix_info->size_info[2] = k;
+ matrix_info->ld_info[0] = lda;
+ matrix_info->ld_info[1] = ldb;
+ matrix_info->ld_info[2] = ldc;
+ matrix_info->groupsize_info = batch_size;
+
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
+ matrix_info->size_info, matrix_info->size_info + 1,
+ matrix_info->size_info + 2, matrix_info->value_info,
+ reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
+ reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+ matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
+ matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+
+ q.submit([&](sycl::handler &cgh)
+ {
+ cgh.depends_on(e);
+ cgh.host_task([=] { std::free(matrix_info); }); });
+ }
+
+ template <class Ta, class Tb, class Tc, class Ts>
+ inline void
+ gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n,
+ int k, const void *alpha, const void *a, int lda,
+ long long int stride_a, const void *b, int ldb,
+ long long int stride_b, const void *beta, void *c,
+ int ldc, long long int stride_c, int batch_size)
+ {
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+ auto data_a = get_memory<const Ta>(a);
+ auto data_b = get_memory<const Tb>(b);
+ auto data_c = get_memory<Tc>(c);
+ oneapi::mkl::blas::column_major::gemm_batch(
+ q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+ stride_a, data_b, ldb, stride_b, beta_value,
+ data_c, ldc, stride_c, batch_size);
+ }
+
+ } // namespace detail
+
+ template <typename VecT, class BinaryOperation>
+ inline unsigned vectorized_binary(unsigned a, unsigned b,
+ const BinaryOperation binary_op)
+ {
+ sycl::vec<unsigned, 1> v0{a}, v1{b};
+ auto v2 = v0.as<VecT>();
+ auto v3 = v1.as<VecT>();
+ auto v4 =
+ detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
+ v0 = v4.template as<sycl::vec<unsigned, 1>>();
+ return v0;
+ }
+
+ static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
+ memcpy_direction direction = automatic,
+ sycl::queue &q = dpct::get_default_queue())
+ {
+ detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
+ }
+
+ static inline unsigned int select_device(unsigned int id)
+ {
+ dev_mgr::instance().select_device(id);
+ return id;
+ }
+
+ template <typename T>
+ T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
+ unsigned int logical_sub_group_size = 32)
+ {
+ unsigned int id = g.get_local_linear_id();
+ unsigned int start_index =
+ id / logical_sub_group_size * logical_sub_group_size;
+ unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
+ return sycl::select_from_group(g, x,
+ target_offset < logical_sub_group_size
+ ? start_index + target_offset
+ : id);
+ }
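+ // e.g. with logical_sub_group_size = 32 and mask = 16, work-item i reads the
+ // value held by work-item i ^ 16 within its 32-lane group - the usual
+ // butterfly step of a sub-group reduction (sketch, `sg` assumed):
+ //   for (unsigned m = 16; m > 0; m >>= 1)
+ //     x += dpct::permute_sub_group_by_xor(sg, x, m);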
+
+ template <typename T>
+ sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
+ {
+ return sycl::vec<T, 1>(val)
+ .template as<sycl::vec<
+ std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
+ .template convert<T>();
+ }
+
+ template <typename T1, typename T2>
+ using dot_product_acc_t =
+ std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
+ uint32_t, int32_t>;
+
+ template <typename T1, typename T2, typename T3>
+ inline auto dp4a(T1 a, T2 b, T3 c)
+ {
+ dot_product_acc_t<T1, T2> res = c;
+ auto va = extract_and_sign_or_zero_extend4(a);
+ auto vb = extract_and_sign_or_zero_extend4(b);
+ res += va[0] * vb[0];
+ res += va[1] * vb[1];
+ res += va[2] * vb[2];
+ res += va[3] * vb[3];
+ return res;
+ }
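+ // Worked example (little-endian target): with int a = 0x01020304,
+ // int b = 0x01010101 and c = 0, the byte lanes of a are (4, 3, 2, 1) and those
+ // of b are all 1, so dp4a(a, b, c) == 4 + 3 + 2 + 1 == 10; the accumulator is
+ // int32_t because both inputs are signed.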
+
+ struct sub_sat
+ {
+ template <typename T>
+ auto operator()(const T x, const T y) const
+ {
+ return sycl::sub_sat(x, y);
+ }
+ };
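+ // Hedged usage sketch: byte-wise saturating subtraction of two packed 32-bit
+ // values, mirroring CUDA's __vsubus4-style intrinsics (the vec type below is
+ // an assumption about the intended packing):
+ //   unsigned r = dpct::vectorized_binary<sycl::vec<uint8_t, 4>>(a, b, dpct::sub_sat());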
+
+ template <typename S, typename T>
+ inline T vectorized_min(T a, T b)
+ {
+ sycl::vec<T, 1> v0{a}, v1{b};
+ auto v2 = v0.template as<S>();
+ auto v3 = v1.template as<S>();
+ auto v4 = sycl::min(v2, v3);
+ v0 = v4.template as<sycl::vec<T, 1>>();
+ return v0;
+ }
+
+ inline float pow(const float a, const int b) { return sycl::pown(a, b); }
+ inline double pow(const double a, const int b) { return sycl::pown(a, b); }
+ inline float pow(const float a, const float b) { return sycl::pow(a, b); }
+ inline double pow(const double a, const double b) { return sycl::pow(a, b); }
+ template <typename T, typename U>
+ inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
+ pow(const T a, const U b)
+ {
+ return sycl::pow(a, static_cast<T>(b));
+ }
+ template <typename T, typename U>
+ inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
+ pow(const T a, const U b)
+ {
+ return sycl::pow(static_cast<double>(a), static_cast<double>(b));
+ }
+
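+ // min function overloads.
+ // For floating-point types, `float` or `double` arguments are acceptable.
+ // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
+ // `std::int64_t` type arguments are acceptable.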
+ inline double min(const double a, const float b)
+ {
+ return sycl::fmin(a, static_cast<double>(b));
+ }
+ inline double min(const float a, const double b)
+ {
+ return sycl::fmin(static_cast<double>(a), b);
+ }
+ inline float min(const float a, const float b) { return sycl::fmin(a, b); }
+ inline double min(const double a, const double b) { return sycl::fmin(a, b); }
+ inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
+ {
+ return sycl::min(a, static_cast<std::uint32_t>(b));
+ }
+ inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
+ {
+ return sycl::min(static_cast<std::uint32_t>(a), b);
+ }
+ inline std::int32_t min(const std::int32_t a, const std::int32_t b)
+ {
+ return sycl::min(a, b);
+ }
+ inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
+ {
+ return sycl::min(a, b);
+ }
+ inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
+ {
+ return sycl::min(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
+ {
+ return sycl::min(static_cast<std::uint64_t>(a), b);
+ }
+ inline std::int64_t min(const std::int64_t a, const std::int64_t b)
+ {
+ return sycl::min(a, b);
+ }
+ inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
+ {
+ return sycl::min(a, b);
+ }
+ inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
+ {
+ return sycl::min(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
+ {
+ return sycl::min(static_cast<std::uint64_t>(a), b);
+ }
+ inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
+ {
+ return sycl::min(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)
+ {
+ return sycl::min(static_cast<std::uint64_t>(a), b);
+ }
+ // max function overloads.
+ // For floating-point types, `float` or `double` arguments are acceptable.
+ // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
+ // `std::int64_t` type arguments are acceptable.
+ inline double max(const double a, const float b)
+ {
+ return sycl::fmax(a, static_cast<double>(b));
+ }
+ inline double max(const float a, const double b)
+ {
+ return sycl::fmax(static_cast<double>(a), b);
+ }
+ inline float max(const float a, const float b) { return sycl::fmax(a, b); }
+ inline double max(const double a, const double b) { return sycl::fmax(a, b); }
+ inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
+ {
+ return sycl::max(a, static_cast<std::uint32_t>(b));
+ }
+ inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
+ {
+ return sycl::max(static_cast<std::uint32_t>(a), b);
+ }
+ inline std::int32_t max(const std::int32_t a, const std::int32_t b)
+ {
+ return sycl::max(a, b);
+ }
+ inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
+ {
+ return sycl::max(a, b);
+ }
+ inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
+ {
+ return sycl::max(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
+ {
+ return sycl::max(static_cast<std::uint64_t>(a), b);
+ }
+ inline std::int64_t max(const std::int64_t a, const std::int64_t b)
+ {
+ return sycl::max(a, b);
+ }
+ inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
+ {
+ return sycl::max(a, b);
+ }
+ inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
+ {
+ return sycl::max(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
+ {
+ return sycl::max(static_cast<std::uint64_t>(a), b);
+ }
+ inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
+ {
+ return sycl::max(a, static_cast<std::uint64_t>(b));
+ }
+ inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
+ {
+ return sycl::max(static_cast<std::uint64_t>(a), b);
+ }
+
+ inline void
+ has_capability_or_fail(const sycl::device &dev,
+ const std::initializer_list<sycl::aspect> &props)
+ {
+ for (const auto &it : props)
+ {
+ if (dev.has(it))
+ continue;
+ switch (it)
+ {
+ case sycl::aspect::fp64:
+ throw std::runtime_error("'double' is not supported in '" +
+ dev.get_info<sycl::info::device::name>() +
+ "' device");
+ break;
+ case sycl::aspect::fp16:
+ throw std::runtime_error("'half' is not supported in '" +
+ dev.get_info<sycl::info::device::name>() +
+ "' device");
+ break;
+ default:
+#define __SYCL_ASPECT(ASPECT, ID) \
+ case sycl::aspect::ASPECT: \
+ return #ASPECT;
+#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
+#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
+ auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string
+ {
+ switch (AspectNum)
+ {
+#include <sycl/info/aspects.def>
+#include <sycl/info/aspects_deprecated.def>
+ default:
+ return "unknown aspect";
+ }
+ };
+#undef __SYCL_ASPECT_DEPRECATED_ALIAS
+#undef __SYCL_ASPECT_DEPRECATED
+#undef __SYCL_ASPECT
+ throw std::runtime_error(
+ "'" + getAspectNameStr(it) + "' is not supported in '" +
+ dev.get_info<sycl::info::device::name>() + "' device");
+ }
+ break;
+ }
+ }
+
+ static inline unsigned int get_current_device_id()
+ {
+ return dev_mgr::instance().current_device_id();
+ }
+
+ static inline device_ext &get_current_device()
+ {
+ return dev_mgr::instance().current_device();
+ }
+
+ static inline device_ext &get_device(unsigned int id)
+ {
+ return dev_mgr::instance().get_device(id);
+ }
+
+ static inline sycl::queue &get_in_order_queue()
+ {
+ return dev_mgr::instance().current_device().in_order_queue();
+ }
+
+ static sycl::event
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
+ memcpy_direction direction,
+ const std::vector<sycl::event> &dep_events = {})
+ {
+ if (!size)
+ return sycl::event{};
+ return q.memcpy(to_ptr, from_ptr, size, dep_events);
+ GGML_UNUSED(direction);
+ }
+
+ // Get the actual extent of the strided copy; buffers sized with it are never exceeded.
+ static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
+ size_t pitch)
+ {
+ return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
+ }
+
+ static inline size_t get_offset(sycl::id<3> id, size_t slice,
+ size_t pitch)
+ {
+ return slice * id.get(2) + pitch * id.get(1) + id.get(0);
+ }
+
+ /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
+ /// and \p from_range to another specified by \p to_ptr and \p to_range.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
+ sycl::range<3> to_range, sycl::range<3> from_range,
+ sycl::id<3> to_id, sycl::id<3> from_id,
+ sycl::range<3> size, memcpy_direction direction,
+ const std::vector<sycl::event> &dep_events = {})
+ {
+ // RAII for host pointer
+ class host_buffer
+ {
+ void *_buf;
+ size_t _size;
+ sycl::queue &_q;
+ const std::vector<sycl::event> &_deps; // events the deferred free depends on
+
+ public:
+ host_buffer(size_t size, sycl::queue &q,
+ const std::vector<sycl::event> &deps)
+ : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
+ void *get_ptr() const { return _buf; }
+ size_t get_size() const { return _size; }
+ ~host_buffer()
+ {
+ if (_buf)
+ {
+ _q.submit([&](sycl::handler &cgh)
+ {
+ cgh.depends_on(_deps);
+ cgh.host_task([buf = _buf] { std::free(buf); }); });
+ }
+ }
+ };
+ std::vector<sycl::event> event_list;
+
+ size_t to_slice = to_range.get(1) * to_range.get(0),
+ from_slice = from_range.get(1) * from_range.get(0);
+ unsigned char *to_surface =
+ (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
+ const unsigned char *from_surface =
+ (const unsigned char *)from_ptr +
+ get_offset(from_id, from_slice, from_range.get(0));
+
+ if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
+ {
+ return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
+ direction, dep_events)};
+ }
+ direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
+ size_t size_slice = size.get(1) * size.get(0);
+ switch (direction)
+ {
+ case host_to_host:
+ for (size_t z = 0; z < size.get(2); ++z)
+ {
+ unsigned char *to_ptr = to_surface;
+ const unsigned char *from_ptr = from_surface;
+ if (to_range.get(0) == from_range.get(0) &&
+ to_range.get(0) == size.get(0))
+ {
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
+ direction, dep_events));
+ }
+ else
+ {
+ for (size_t y = 0; y < size.get(1); ++y)
+ {
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
+ direction, dep_events));
+ to_ptr += to_range.get(0);
+ from_ptr += from_range.get(0);
+ }
+ }
+ to_surface += to_slice;
+ from_surface += from_slice;
+ }
+ break;
+ case host_to_device:
+ {
+ host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
+ event_list);
+ std::vector<sycl::event> host_events;
+ if (to_slice == size_slice)
+ {
+ // Copy host data to a temp host buffer with the shape of target.
+ host_events =
+ dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
+ host_to_host, dep_events);
+ }
+ else
+ {
+ // Copy host data to a temp host buffer with the shape of target.
+ host_events = dpct_memcpy(
+ q, buf.get_ptr(), from_surface, to_range, from_range,
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
+ // The destination may contain padding bytes whose contents might still matter,
+ // so prefill the temp buffer with the existing device data before overlaying.
+ std::vector<sycl::event>{
+ dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
+ device_to_host, dep_events)});
+ }
+ // Copy from temp host buffer to device with only one submit.
+ event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
+ buf.get_size(), host_to_device,
+ host_events));
+ break;
+ }
+ case device_to_host:
+ {
+ host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
+ event_list);
+ // Copy from host temp buffer to host target with reshaping.
+ event_list = dpct_memcpy(
+ q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
+ sycl::id<3>(0, 0, 0), size, host_to_host,
+ // Copy from device to temp host buffer with only one submit.
+ std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
+ buf.get_size(),
+ device_to_host, dep_events)});
+ break;
+ }
+ case device_to_device:
+ event_list.push_back(q.submit([&](sycl::handler &cgh)
+ {
+ cgh.depends_on(dep_events);
+ cgh.parallel_for<class dpct_memcpy_3d_detail>(
+ size,
+ [=](sycl::id<3> id) {
+ to_surface[get_offset(id, to_slice, to_range.get(0))] =
+ from_surface[get_offset(id, from_slice, from_range.get(0))];
+ }); }));
+ break;
+ default:
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
+ }
+ return event_list;
+ }
+
+ /// memcpy 2D/3D matrix specified by pitched_data.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
+ pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
+ memcpy_direction direction = automatic)
+ {
+ return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
+ sycl::range<3>(to.get_pitch(), to.get_y(), 1),
+ sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
+ size, direction);
+ }
+
+ /// memcpy 2D matrix with pitch.
+ static inline std::vector<sycl::event>
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
+ size_t to_pitch, size_t from_pitch, size_t x, size_t y,
+ memcpy_direction direction = automatic)
+ {
+ return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
+ sycl::range<3>(from_pitch, y, 1),
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
+ sycl::range<3>(x, y, 1), direction);
+ }
+
+ inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n, int k,
+ const void *alpha, const void *a, library_data_t a_type,
+ int lda, const void *b, library_data_t b_type, int ldb,
+ const void *beta, void *c, library_data_t c_type, int ldc,
+ library_data_t scaling_type)
+ {
+ if (scaling_type == library_data_t::real_float &&
+ c_type == library_data_t::complex_float)
+ {
+ scaling_type = library_data_t::complex_float;
+ }
+ else if (scaling_type == library_data_t::real_double &&
+ c_type == library_data_t::complex_double)
+ {
+ scaling_type = library_data_t::complex_double;
+ }
+
+ std::uint64_t key =
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
+ switch (key)
+ {
+ case detail::get_type_combination_id(
+ library_data_t::real_float, library_data_t::real_float,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_impl<float, float, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_double, library_data_t::real_double,
+ library_data_t::real_double, library_data_t::real_double):
+ {
+ detail::gemm_impl<double, double, double, double>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_float, library_data_t::complex_float,
+ library_data_t::complex_float, library_data_t::complex_float):
+ {
+ detail::gemm_impl<std::complex<float>, std::complex<float>,
+ std::complex<float>, std::complex<float>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_double, library_data_t::complex_double,
+ library_data_t::complex_double, library_data_t::complex_double):
+ {
+ detail::gemm_impl<std::complex<double>, std::complex<double>,
+ std::complex<double>, std::complex<double>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_half):
+ {
+ detail::gemm_impl<sycl::half, sycl::half, sycl::half,
+ sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,
+ lda, b, ldb, beta, c, ldc);
+ break;
+ }
+#ifdef __INTEL_MKL__
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
+ float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
+ ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_impl<sycl::half, sycl::half, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_float):
+ {
+ float alpha_value =
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
+ float beta_value =
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
+ sycl::half alpha_half(alpha_value);
+ sycl::half beta_half(beta_value);
+ detail::gemm_impl<sycl::half, sycl::half, sycl::half,
+ sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,
+ a, lda, b, ldb, &beta_half, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_bfloat16, library_data_t::real_float):
+ {
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
+ oneapi::mkl::bfloat16, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_int32, library_data_t::real_int32):
+ {
+ float alpha_float =
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
+ float beta_float =
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
+ detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
+ break;
+ }
+#endif // __INTEL_MKL__
+ default:
+ throw std::runtime_error("the combination of data type is unsupported");
+ }
+ } // gemm()
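+ // Hedged usage sketch: plain fp32 GEMM, C = alpha*A*B + beta*C, on the current
+ // device's in-order queue (a_dev/b_dev/c_dev are assumed USM device pointers,
+ // matrices stored column-major as oneMKL expects):
+ //   float alpha = 1.0f, beta = 0.0f;
+ //   dpct::gemm(dpct::get_in_order_queue(),
+ //              oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
+ //              m, n, k, &alpha,
+ //              a_dev, dpct::library_data_t::real_float, lda,
+ //              b_dev, dpct::library_data_t::real_float, ldb, &beta,
+ //              c_dev, dpct::library_data_t::real_float, ldc,
+ //              dpct::library_data_t::real_float);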
+
+ /// Computes a batch of matrix-matrix product with general matrices.
+ /// \param [in] q The queue where the routine should be executed.
+ /// \param [in] a_trans Specifies the operation applied to A.
+ /// \param [in] b_trans Specifies the operation applied to B.
+ /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
+ /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
+ /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
+ /// \param [in] alpha Scaling factor for the matrix-matrix product.
+ /// \param [in] a Input matrix A.
+ /// \param [in] a_type Data type of the matrix A.
+ /// \param [in] lda Leading dimension of A.
+ /// \param [in] b Input matrix B.
+ /// \param [in] b_type Data type of the matrix B.
+ /// \param [in] ldb Leading dimension of B.
+ /// \param [in] beta Scaling factor for matrix C.
+ /// \param [in, out] c Input/Output matrix C.
+ /// \param [in] c_type Data type of the matrix C.
+ /// \param [in] ldc Leading dimension of C.
+ /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
+ /// \param [in] scaling_type Data type of the scaling factors.
+ inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n, int k,
+ const void *alpha, const void *a[],
+ library_data_t a_type, int lda, const void *b[],
+ library_data_t b_type, int ldb, const void *beta,
+ void *c[], library_data_t c_type, int ldc,
+ int batch_size, library_data_t scaling_type)
+ {
+ if (scaling_type == library_data_t::real_float &&
+ c_type == library_data_t::complex_float)
+ {
+ scaling_type = library_data_t::complex_float;
+ }
+ else if (scaling_type == library_data_t::real_double &&
+ c_type == library_data_t::complex_double)
+ {
+ scaling_type = library_data_t::complex_double;
+ }
+
+ std::uint64_t key =
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
+ switch (key)
+ {
+ case detail::get_type_combination_id(
+ library_data_t::real_float, library_data_t::real_float,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<float, float, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_double, library_data_t::real_double,
+ library_data_t::real_double, library_data_t::real_double):
+ {
+ detail::gemm_batch_impl<double, double, double, double>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_float, library_data_t::complex_float,
+ library_data_t::complex_float, library_data_t::complex_float):
+ {
+ detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
+ std::complex<float>, std::complex<float>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_double, library_data_t::complex_double,
+ library_data_t::complex_double, library_data_t::complex_double):
+ {
+ detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
+ std::complex<double>, std::complex<double>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_half):
+ {
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
+ sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
+ a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+#ifdef __INTEL_MKL__
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_bfloat16, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
+ oneapi::mkl::bfloat16, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
+ float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_size);
+ break;
+ }
+#endif
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_int32, library_data_t::real_int32):
+ {
+ float alpha_float =
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
+ float beta_float =
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
+ float>(q, a_trans, b_trans, m, n, k, &alpha_float,
+ a, lda, b, ldb, &beta_float, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
+ batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_float):
+ {
+ float alpha_value =
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
+ float beta_value =
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
+ sycl::half alpha_half(alpha_value);
+ sycl::half beta_half(beta_value);
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
+ batch_size);
+ break;
+ }
+ default:
+ throw std::runtime_error("the combination of data type is unsupported");
+ }
+ }
+
+ /// Computes a batch of matrix-matrix product with general matrices.
+ /// \param [in] q The queue where the routine should be executed.
+ /// \param [in] a_trans Specifies the operation applied to A.
+ /// \param [in] b_trans Specifies the operation applied to B.
+ /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
+ /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
+ /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
+ /// \param [in] alpha Scaling factor for the matrix-matrix product.
+ /// \param [in] a Input matrix A.
+ /// \param [in] a_type Data type of the matrix A.
+ /// \param [in] lda Leading dimension of A.
+ /// \param [in] stride_a Stride between the different A matrices.
+ /// \param [in] b Input matrix B.
+ /// \param [in] b_type Data type of the matrix B.
+ /// \param [in] ldb Leading dimension of B.
+ /// \param [in] stride_b Stride between the different B matrices.
+ /// \param [in] beta Scaling factor for matrix C.
+ /// \param [in, out] c Input/Output matrix C.
+ /// \param [in] c_type Data type of the matrix C.
+ /// \param [in] ldc Leading dimension of C.
+ /// \param [in] stride_c Stride between the different C matrices.
+ /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
+ /// \param [in] scaling_type Data type of the scaling factors.
+ inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
+ oneapi::mkl::transpose b_trans, int m, int n, int k,
+ const void *alpha, const void *a, library_data_t a_type,
+ int lda, long long int stride_a, const void *b,
+ library_data_t b_type, int ldb, long long int stride_b,
+ const void *beta, void *c, library_data_t c_type,
+ int ldc, long long int stride_c, int batch_size,
+ library_data_t scaling_type)
+ {
+ if (scaling_type == library_data_t::real_float &&
+ c_type == library_data_t::complex_float)
+ {
+ scaling_type = library_data_t::complex_float;
+ }
+ else if (scaling_type == library_data_t::real_double &&
+ c_type == library_data_t::complex_double)
+ {
+ scaling_type = library_data_t::complex_double;
+ }
+
+ std::uint64_t key =
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
+ switch (key)
+ {
+ case detail::get_type_combination_id(
+ library_data_t::real_float, library_data_t::real_float,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<float, float, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_double, library_data_t::real_double,
+ library_data_t::real_double, library_data_t::real_double):
+ {
+ detail::gemm_batch_impl<double, double, double, double>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_float, library_data_t::complex_float,
+ library_data_t::complex_float, library_data_t::complex_float):
+ {
+ detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
+ std::complex<float>, std::complex<float>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::complex_double, library_data_t::complex_double,
+ library_data_t::complex_double, library_data_t::complex_double):
+ {
+ detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
+ std::complex<double>, std::complex<double>>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_half):
+ {
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
+ sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
+ a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+#ifdef __INTEL_MKL__
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_bfloat16, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
+ oneapi::mkl::bfloat16, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
+ float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
+ stride_a, b, ldb, stride_b, beta, c, ldc,
+ stride_c, batch_size);
+ break;
+ }
+#endif
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_int32, library_data_t::real_int32):
+ {
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
+ std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
+ a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_int8, library_data_t::real_int8,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_float, library_data_t::real_float):
+ {
+ detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
+ beta, c, ldc, stride_c, batch_size);
+ break;
+ }
+ case detail::get_type_combination_id(
+ library_data_t::real_half, library_data_t::real_half,
+ library_data_t::real_half, library_data_t::real_float):
+ {
+ float alpha_value =
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
+ float beta_value =
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
+ sycl::half alpha_half(alpha_value);
+ sycl::half beta_half(beta_value);
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
+ &beta_half, c, ldc, stride_c, batch_size);
+ break;
+ }
+ default:
+    throw std::runtime_error("the combination of data types is unsupported");
+ }
+ }
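+  // Note: the switch above dispatches on the runtime
+  // (a_type, b_type, c_type, scaling_type) tuple encoded by
+  // detail::get_type_combination_id() and instantiates the matching
+  // detail::gemm_batch_impl<>. For example, the
+  // (real_half, real_half, real_half, real_float) case reads the float
+  // alpha/beta values, converts them to sycl::half and then runs the
+  // half-precision batched GEMM.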
+
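+  // 2-D (pitched) copy helper: forwards to detail::dpct_memcpy with separate
+  // source/destination row pitches; following the usual dpct convention, x is
+  // the copy width in bytes and y the number of rows. The copy is submitted
+  // asynchronously, so the caller is responsible for synchronizing on q.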
+ static inline void
+ async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
+ size_t from_pitch, size_t x, size_t y,
+ memcpy_direction direction = automatic,
+ sycl::queue &q = get_default_queue())
+ {
+ detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
+ direction);
+ }
+
+ using err0 = detail::generic_error_type<struct err0_tag, int>;
+ using err1 = detail::generic_error_type<struct err1_tag, int>;
+
+ static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {
+ detail::dpct_free(ptr, q);
+ }
+
+ /// dpct accessor used as device function parameter.
+ template <class T, memory_region Memory, size_t Dimension> class accessor;
+ template <class T, memory_region Memory> class accessor<T, Memory, 3> {
+ public:
+ using memory_t = detail::memory_traits<Memory, T>;
+ using element_t = typename memory_t::element_t;
+ using pointer_t = typename memory_t::pointer_t;
+ using accessor_t = typename memory_t::template accessor_t<3>;
+ accessor(pointer_t data, const sycl::range<3> &in_range)
+ : _data(data), _range(in_range) {}
+ template <memory_region M = Memory>
+ accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
+ : accessor(acc, acc.get_range()) {}
+ accessor(const accessor_t &acc, const sycl::range<3> &in_range)
+ : accessor(acc.get_pointer(), in_range) {}
+ accessor<T, Memory, 2> operator[](size_t index) const {
+ sycl::range<2> sub(_range.get(1), _range.get(2));
+ return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
+ }
+
+ pointer_t get_ptr() const { return _data; }
+
+ private:
+ pointer_t _data;
+ sycl::range<3> _range;
+ };
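+  // Note: operator[] of the 3-D accessor above peels off the outermost
+  // dimension: index i yields a 2-D accessor over the remaining
+  // range(1) x range(2) elements starting at _data + i * range(1) * range(2).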
+ template <class T, memory_region Memory> class accessor<T, Memory, 2> {
+ public:
+ using memory_t = detail::memory_traits<Memory, T>;
+ using element_t = typename memory_t::element_t;
+ using pointer_t = typename memory_t::pointer_t;
+ using accessor_t = typename memory_t::template accessor_t<2>;
+ accessor(pointer_t data, const sycl::range<2> &in_range)
+ : _data(data), _range(in_range) {}
+ template <memory_region M = Memory>
+ accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
+ : accessor(acc, acc.get_range()) {}
+ accessor(const accessor_t &acc, const sycl::range<2> &in_range)
+ : accessor(acc.get_pointer(), in_range) {}
+
+ pointer_t operator[](size_t index) const {
+ return _data + _range.get(1) * index;
+ }
+
+ pointer_t get_ptr() const { return _data; }
+
+ private:
+ pointer_t _data;
+ sycl::range<2> _range;
+ };
+
+ namespace detail {
+ /// Device variable with address space of shared, global or constant.
+ template <class T, memory_region Memory, size_t Dimension> class device_memory {
+ public:
+ using accessor_t =
+ typename detail::memory_traits<Memory,
+ T>::template accessor_t<Dimension>;
+ using value_t = typename detail::memory_traits<Memory, T>::value_t;
+ using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
+
+ device_memory() : device_memory(sycl::range<Dimension>(1)) {}
+
+ /// Constructor of 1-D array with initializer list
+ device_memory(const sycl::range<Dimension> &in_range,
+ std::initializer_list<value_t> &&init_list)
+ : device_memory(in_range) {
+ assert(init_list.size() <= in_range.size());
+ _host_ptr = (value_t *)std::malloc(_size);
+ std::memset(_host_ptr, 0, _size);
+ std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
+ }
+
+ /// Constructor of 2-D array with initializer list
+ template <size_t D = Dimension>
+ device_memory(
+ const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
+ std::initializer_list<std::initializer_list<value_t>> &&init_list)
+ : device_memory(in_range) {
+ assert(init_list.size() <= in_range[0]);
+ _host_ptr = (value_t *)std::malloc(_size);
+ std::memset(_host_ptr, 0, _size);
+ auto tmp_data = _host_ptr;
+ for (auto sub_list : init_list) {
+ assert(sub_list.size() <= in_range[1]);
+ std::memcpy(tmp_data, sub_list.begin(),
+ sub_list.size() * sizeof(T));
+ tmp_data += in_range[1];
+ }
+ }
+
+ /// Constructor with range
+ device_memory(const sycl::range<Dimension> &range_in)
+ : _size(range_in.size() * sizeof(T)), _range(range_in),
+ _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
+ static_assert(
+ (Memory == global) || (Memory == constant) || (Memory == shared),
+ "device memory region should be global, constant or shared");
+ // Make sure that singleton class mem_mgr and dev_mgr will destruct
+ // later than this.
+ detail::mem_mgr::instance();
+ dev_mgr::instance();
+ }
+
+ /// Constructor with range
+ template <class... Args>
+ device_memory(Args... Arguments)
+ : device_memory(sycl::range<Dimension>(Arguments...)) {}
+
+ ~device_memory() {
+ if (_device_ptr && !_reference)
+ dpct::dpct_free(_device_ptr);
+ if (_host_ptr)
+ std::free(_host_ptr);
+ }
+
+  /// Allocate memory with the default queue, and initialize it if an initial
+  /// value was provided.
+ void init() { init(dpct::get_default_queue()); }
+  /// Allocate memory with the specified queue, and initialize it if an initial
+  /// value was provided.
+ void init(sycl::queue &q) {
+ if (_device_ptr)
+ return;
+ if (!_size)
+ return;
+ allocate_device(q);
+ if (_host_ptr)
+ detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,
+ host_to_device);
+ }
+
+ /// The variable is assigned to a device pointer.
+ void assign(value_t *src, size_t size) {
+ this->~device_memory();
+ new (this) device_memory(src, size);
+ }
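+  // Note: assign() destroys the current object in place and placement-news a
+  // non-owning wrapper (_reference == true) around an existing device pointer,
+  // so the destructor will not free memory this object does not own.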
+
+ /// Get memory pointer of the memory object, which is virtual pointer when
+ /// usm is not used, and device pointer when usm is used.
+ value_t *get_ptr() { return get_ptr(get_default_queue()); }
+ /// Get memory pointer of the memory object, which is virtual pointer when
+ /// usm is not used, and device pointer when usm is used.
+ value_t *get_ptr(sycl::queue &q) {
+ init(q);
+ return _device_ptr;
+ }
+
+ /// Get the device memory object size in bytes.
+ size_t get_size() { return _size; }
+
+ template <size_t D = Dimension>
+ typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
+ init();
+ return _device_ptr[index];
+ }
+
+ /// Get dpct::accessor with dimension info for the device memory object
+ /// when usm is used and dimension is greater than 1.
+ template <size_t D = Dimension>
+ typename std::enable_if<D != 1, dpct_accessor_t>::type
+ get_access([[maybe_unused]] sycl::handler &cgh) {
+ return dpct_accessor_t((T *)_device_ptr, _range);
+ }
+
+ private:
+ device_memory(value_t *memory_ptr, size_t size)
+ : _size(size), _range(size / sizeof(T)), _reference(true),
+ _device_ptr(memory_ptr) {}
+
+ void allocate_device(sycl::queue &q) {
+ #ifndef DPCT_USM_LEVEL_NONE
+ if (Memory == shared) {
+ _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
+ q.get_context());
+ return;
+ }
+ #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
+ if (Memory == constant) {
+ _device_ptr = (value_t *)sycl::malloc_device(
+ _size, q.get_device(), q.get_context(),
+ sycl::ext::oneapi::property::usm::device_read_only());
+ return;
+ }
+ #endif
+ #endif
+ _device_ptr = (value_t *)detail::dpct_malloc(_size, q);
+ }
+
+ size_t _size;
+ sycl::range<Dimension> _range;
+ bool _reference;
+ value_t *_host_ptr;
+ value_t *_device_ptr;
+ };
+ template <class T, memory_region Memory>
+ class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
+ public:
+ using base = device_memory<T, Memory, 1>;
+ using value_t = typename base::value_t;
+ using accessor_t =
+ typename detail::memory_traits<Memory, T>::template accessor_t<0>;
+
+ /// Constructor with initial value.
+ device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
+
+ /// Default constructor
+ device_memory() : base(1) {}
+ };
+ } // namespace detail
+
+ template <class T, size_t Dimension>
+ using global_memory = detail::device_memory<T, global, Dimension>;
+ template <class T, size_t Dimension>
+ using constant_memory = detail::device_memory<T, constant, Dimension>;
+ template <class T, size_t Dimension>
+ using shared_memory = detail::device_memory<T, shared, Dimension>;
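+ // Illustrative usage of the aliases above (hypothetical variable name):
+ //   static dpct::global_memory<float, 1> g_buf(256);
+ // declares a lazily-allocated 1-D array of 256 floats; the first call to
+ // g_buf.get_ptr(q) runs init(q), which allocates the device buffer and
+ // uploads the initial value, if one was given, before returning the pointer.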
+
+
+ template <typename T,
+ sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
+ inline T atomic_fetch_add(T *addr, T operand) {
+ auto atm =
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+ auto atm =
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+ return atm.fetch_add(operand);
+ }
+
+ template <typename T, sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space>
+ inline T atomic_fetch_add(T *addr, T operand,
+ sycl::memory_order memoryOrder) {
+ switch (memoryOrder) {
+ case sycl::memory_order::relaxed:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::acq_rel:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+ sycl::memory_scope::device>(addr, operand);
+ case sycl::memory_order::seq_cst:
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+ sycl::memory_scope::device>(addr, operand);
+ default:
+    assert(false && "Invalid memory_order for atomics. Valid memory_order "
+                    "values for atomics are: sycl::memory_order::relaxed, "
+                    "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+ }
+ }
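+ // The overloads taking a runtime sycl::memory_order simply forward to the
+ // compile-time template above; only relaxed, acq_rel and seq_cst are
+ // accepted.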
+
+ template <sycl::access::address_space addressSpace =
+ sycl::access::address_space::global_space,
+ typename T1, typename T2>
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+ sycl::memory_order memoryOrder) {
+  return atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+ }
+
+} // COPY from DPCT head files
+
+#endif // GGML_SYCL_DPCT_HELPER_HPP
diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp
new file mode 100644
index 00000000..3107ba91
--- /dev/null
+++ b/ggml/src/ggml-sycl/mmq.cpp
@@ -0,0 +1,3031 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "mmq.hpp"
+#include "vecdotq.hpp"
+
+typedef void (*allocate_tiles_sycl_t)(
+ int** x_ql,
+ sycl::half2** x_dm,
+ int** x_qh,
+ int** x_sc);
+typedef void (*load_tiles_sycl_t)(
+ const void* __restrict__ vx,
+ int* __restrict__ x_ql,
+ sycl::half2* __restrict__ x_dm,
+ int* __restrict__ x_qh,
+ int* __restrict__ x_sc,
+ const int& i_offset,
+ const int& i_max,
+ const int& k,
+ const int& blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_sycl_t)(
+ const int* __restrict__ x_ql,
+ const sycl::half2* __restrict__ x_dm,
+ const int* __restrict__ x_qh,
+ const int* __restrict__ x_sc,
+ const int* __restrict__ y_qs,
+    const sycl::half2* __restrict__ y_ds,
+ const int& i,
+ const int& j,
+ const int& k);
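+// These three hook types parameterize the generic mul_mat_q kernel below:
+// allocate_tiles_* binds the per-quantization local tile buffers,
+// load_tiles_* unpacks a tile of the quantized x matrix into them, and
+// vec_dot_*_mul_mat computes the partial dot product of one x-tile row with
+// one q8_1 y-tile column for a single (i, j, k) position.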
+
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
+ (void)x_qh; (void)x_sc;
+
+ *x_ql = tile_x_qs_q4_0;
+ *x_dm = (sycl::half2 *)tile_x_d_q4_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh; (void)x_sc;
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI4_0;
+ const int kqsx = k % QI4_0;
+
+ const block_q4_0 * bx0 = (const block_q4_0 *) vx;
+
+ float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+ // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+ const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+ int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+ }
+}
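+// Note on the tile layout used above: rows of x_ql are strided by
+// WARP_SIZE + 1 rather than WARP_SIZE (the extra element presumably avoids
+// local-memory bank conflicts), and the per-block q4_0 scales are written
+// through a float view of x_dm, one float per block instead of a half2.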
+
+static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh; (void)x_sc;
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+ const float * x_dmf = (const float *) x_dm;
+
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+ }
+
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+ (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+ y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
+ (void)x_qh; (void)x_sc;
+
+ *x_ql = tile_x_qs_q4_1;
+ *x_dm = tile_x_dm_q4_1;
+}
+
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh; (void)x_sc;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI4_1;
+ const int kqsx = k % QI4_1;
+
+ const block_q4_1 * bx0 = (const block_q4_1 *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+ const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+ int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
+ }
+}
+
+static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh; (void)x_sc;
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+ int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
+ }
+
+ return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+ (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
+ y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
+ (void)x_qh; (void)x_sc;
+
+ *x_ql = tile_x_ql_q5_0;
+ *x_dm = (sycl::half2 *)tile_x_d_q5_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh; (void)x_sc;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI5_0;
+ const int kqsx = k % QI5_0;
+
+ const block_q5_0 * bx0 = (const block_q5_0 *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+ const int ql = get_int_from_uint8(bxi->qs, kqsx);
+ const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
+
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+ qs0 = dpct::vectorized_binary<sycl::char4>(
+ qs0, 0x10101010, dpct::sub_sat()); // subtract 16
+
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+ qs1 = dpct::vectorized_binary<sycl::char4>(
+ qs1, 0x10101010, dpct::sub_sat()); // subtract 16
+
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+ const int kbxd = k % blocks_per_tile_x_row;
+ float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+ int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
+ }
+}
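+// Note: the q5_0 unpacking above merges the fifth bit of each value (taken
+// from qh) into bit positions 4/12/20/28 of the nibbles unpacked from qs and
+// then subtracts 16 per byte with a saturating vectorized subtraction, so the
+// stored values land in the signed q5_0 range.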
+
+static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh; (void)x_sc;
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+ const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
+ const float * x_dmf = (const float *) x_dm;
+ const float * y_df = (const float *) y_ds;
+
+ int u[2*VDR_Q5_0_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
+ }
+
+ return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
+ (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
+ (void)x_qh; (void)x_sc;
+
+ *x_ql = tile_x_ql_q5_1;
+ *x_dm = tile_x_dm_q5_1;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh; (void)x_sc;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI5_1;
+ const int kqsx = k % QI5_1;
+
+ const block_q5_1 * bx0 = (const block_q5_1 *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
+
+ const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
+
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+
+ x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+ const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+ int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
+ }
+}
+
+static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh; (void)x_sc;
+
+ const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;
+
+ int u[2*VDR_Q5_1_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
+ }
+
+ return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+ (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
+ (void)x_qh; (void)x_sc;
+
+ *x_ql = tile_x_qs_q8_0;
+ *x_dm = (sycl::half2 *)tile_x_d_q8_0;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh; (void)x_sc;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI8_0;
+ const int kqsx = k % QI8_0;
+ float * x_dmf = (float *) x_dm;
+
+ const block_q8_0 * bx0 = (const block_q8_0 *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+ const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+ int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+ }
+}
+
+static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh; (void)x_sc;
+
+ const float * x_dmf = (const float *) x_dm;
+ const float * y_df = (const float *) y_ds;
+
+ return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
+ (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
+ y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
+ int *tile_x_sc_q2_K) {
+ (void)x_qh;
+
+ *x_ql = tile_x_ql_q2_K;
+ *x_dm = tile_x_dm_q2_K;
+ *x_sc = tile_x_sc_q2_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI2_K;
+ const int kqsx = k % QI2_K;
+
+ const block_q2_K * bx0 = (const block_q2_K *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
+ const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+ int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
+
+ x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
+ }
+}
+
+#define VDR_Q2_K_Q8_1_MMQ 2
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+ const uint8_t *__restrict__ scales,
+ const sycl::half2 &dm2, const float &d8) {
+
+ int sumi_d = 0;
+ int sumi_m = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+ int sumi_d_sc = 0;
+
+ const int sc = scales[i0 / (QI8_1/2)];
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+ sumi_m = dpct::dp4a(m, u[i],
+ sumi_m); // multiply sum of q8_1 values with m
+ }
+
+ sumi_d += sumi_d_sc * (sc & 0xF);
+ }
+
+ const sycl::float2 dm2f =
+ dm2.convert<float, sycl::rounding_mode::automatic>();
+
+ return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
+}
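+// In the helper above the 4-bit minimum m is replicated into all four bytes of
+// an int so that dpct::dp4a(m, u[i], sumi_m) accumulates m * (sum of four
+// q8_1 values); the final d8 * (d * sumi_d - dmin * sumi_m) uses the (d, dmin)
+// pair packed into the q2_K dm field.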
+
+static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh;
+
+ const int kbx = k / QI2_K;
+ const int ky = (k % QI2_K) * QR2_K;
+ const float * y_df = (const float *) y_ds;
+
+ int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
+
+ const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+ const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
+
+#pragma unroll
+ for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+ v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
+ }
+
+ const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
+
+ const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
+ return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
+ int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
+
+ *x_ql = tile_x_ql_q3_K;
+ *x_dm = tile_x_dm_q3_K;
+ *x_qh = tile_x_qh_q3_K;
+ *x_sc = tile_x_sc_q3_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI3_K;
+ const int kqsx = k % QI3_K;
+
+ const block_q3_K * bx0 = (const block_q3_K *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
+ const int kbxd = k % blocks_per_tile_x_row;
+ float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+ int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
+ int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+ int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
+
+ const int ksc = k % (QI3_K/4);
+
+ const int ksc_low = ksc % (QI3_K/8);
+ const int shift_low = 4 * (ksc / (QI3_K/8));
+ const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+ const int ksc_high = QI3_K/8;
+ const int shift_high = 2 * ksc;
+ const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+ const int sc = dpct::vectorized_binary<sycl::char4>(
+ sc_low | sc_high, 0x20202020, dpct::sub_sat());
+
+ x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
+ }
+}
+
+#define VDR_Q3_K_Q8_1_MMQ 2
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+ const int8_t *__restrict__ scales, const float &d3,
+ const float &d8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+ int sumi_sc = 0;
+
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+ }
+
+ sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+ }
+
+ return d3*d8 * sumi;
+}
+
+static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+
+ const int kbx = k / QI3_K;
+ const int ky = (k % QI3_K) * QR3_K;
+ const float * x_dmf = (const float *) x_dm;
+ const float * y_df = (const float *) y_ds;
+
+ const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+
+ int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+ const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+ const int shift = 2 * ((ky % 32) / 8);
+ const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
+
+ const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+ const int vlh = (vh << 2) & 0x04040404;
+
+ v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
+ }
+
+ const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
+ return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
+ int *tile_x_sc_q4_K) {
+ (void)x_qh;
+
+ *x_ql = tile_x_ql_q4_K;
+ *x_dm = tile_x_dm_q4_K;
+ *x_sc = tile_x_sc_q4_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI4_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI4_K; // == k if QK_K == 256
+
+ const block_q4_K * bx0 = (const block_q4_K *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
+
+ x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
+ const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+ int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+#if QK_K == 256
+ x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+ x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+ int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+ scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
+
+ x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+ }
+}
+
+
+#define VDR_Q4_K_Q8_1_MMQ 8
+
+// contiguous u/y values
+static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
+ const int *__restrict__ v, const int *__restrict__ u,
+ const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+ const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
+ u[i * QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const sycl::float2 ds8f =
+ ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+ sumf_d += ds8f.x() * (sc[i] * sumi_d);
+ sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
+ }
+
+ const sycl::float2 dm4f =
+ dm4.convert<float, sycl::rounding_mode::automatic>();
+
+ return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+
+static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh;
+
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
+
+ const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
+ return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+ x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
+ int *tile_x_sc_q5_K) {
+ (void)x_qh;
+
+ *x_ql = tile_x_ql_q5_K;
+ *x_dm = tile_x_dm_q5_K;
+ *x_sc = tile_x_sc_q5_K;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI5_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI5_K; // == k if QK_K == 256
+
+ const block_q5_K * bx0 = (const block_q5_K *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+ const int ky = QR5_K*kqsx;
+
+ const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+ const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+ const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+ const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+ const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+ x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+ x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
+ const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+ int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+#if QK_K == 256
+ x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+ int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+ scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
+
+ x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
+ }
+}
+
+#define VDR_Q5_K_Q8_1_MMQ 8
+
+// contiguous u/y values
+static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
+ const int *__restrict__ v, const int *__restrict__ u,
+ const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+ const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
+ sumi_d); // SIMD dot product
+ }
+
+ const sycl::float2 ds8f =
+ ds8[i].convert<float, sycl::rounding_mode::automatic>();
+
+ sumf_d += ds8f.x() * (sc[i] * sumi_d);
+        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q5_K min val
+ }
+
+ const sycl::float2 dm4f =
+ dm4.convert<float, sycl::rounding_mode::automatic>();
+
+ return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh;
+
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
+
+ const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+ const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
+ return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+ x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+}
+
+template <int mmq_y>
+static __dpct_inline__ void
+allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
+ int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
+ (void)x_qh;
+
+ *x_ql = tile_x_ql;
+ *x_dm = tile_x_dm;
+ *x_sc = tile_x_sc;
+}
+
+template <int mmq_y, int nwarps, bool need_check>
+static __dpct_inline__ void
+load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
+ sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
+ int *__restrict__ x_sc, const int &i_offset, const int &i_max,
+ const int &k, const int &blocks_per_row) {
+ (void)x_qh;
+
+ GGML_SYCL_ASSUME(i_offset >= 0);
+ GGML_SYCL_ASSUME(i_offset < nwarps);
+ GGML_SYCL_ASSUME(k >= 0);
+ GGML_SYCL_ASSUME(k < WARP_SIZE);
+
+ const int kbx = k / QI6_K; // == 0 if QK_K == 256
+ const int kqsx = k % QI6_K; // == k if QK_K == 256
+
+ const block_q6_K * bx0 = (const block_q6_K *) vx;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+ int i = i0 + i_offset;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+ const int ky = QR6_K*kqsx;
+
+ const int ql = get_int_from_uint8(bxi->ql, kqsx);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+ const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+ const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
+
+ const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+ const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+ x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
+ dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
+ dpct::sub_sat());
+ x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
+ dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
+ dpct::sub_sat());
+ }
+
+ const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
+ const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
+ float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+ int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
+
+ x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
+
+ if (need_check) {
+ i = sycl::min(i, i_max);
+ }
+
+ const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
+
+ x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
+ }
+}
+
+#define VDR_Q6_K_Q8_1_MMQ 8
+
+// contiguous u/y values
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
+ const int8_t *__restrict__ sc, const float &d6,
+ const float *__restrict__ d8) {
+
+ float sumf_d = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+ sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+ for (int i = i0; i < i0 + 2; ++i) {
+ sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
+ sumi_d.x()); // SIMD dot product
+ sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
+ sumi_d.x()); // SIMD dot product
+
+ sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
+ sumi_d.y()); // SIMD dot product
+ sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
+ sumi_d.y()); // SIMD dot product
+ }
+
+ sumf_d += d8[i0 / 4] *
+ (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
+ }
+
+ return d6 * sumf_d;
+}
+
+static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
+ const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
+ const int *__restrict__ x_qh, const int *__restrict__ x_sc,
+ const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
+ const int &i, const int &j, const int &k) {
+ (void)x_qh;
+
+ const float * x_dmf = (const float *) x_dm;
+ const float * y_df = (const float *) y_ds;
+
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
+
+ const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+ const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
+ return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
+}
+
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
+ int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
+ vec_dot_q_mul_mat_sycl_t vec_dot>
+/*
+DPCT1110:8: The total declared local variable size in device function mul_mat_q
+exceeds 128 bytes and may cause high register pressure. Consult with your
+hardware vendor to find the total register size available and adjust the code,
+or use smaller sub-group size to avoid high register pressure.
+*/
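+/*
+Kernel overview: each work-group computes an mmq_y x mmq_x tile of dst. The
+outer ib0 loop walks the shared dimension in steps of blocks_per_warp;
+load_tiles stages a tile of the quantized x matrix in local memory, the inner
+loops stage the corresponding q8_1 y tile, and vec_dot accumulates into the
+per-thread sum[][] array. The two barriers separate the load and compute
+phases of each iteration.
+*/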
+static __dpct_inline__ void
+mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols_x, const int nrows_x,
+ const int ncols_y, const int nrows_y, const int nrows_dst,
+ int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
+ int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
+ sycl::half2 *tile_y_ds) {
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ const int blocks_per_row_x = ncols_x / qk;
+ const int blocks_per_col_y = nrows_y / QK8_1;
+ const int blocks_per_warp = WARP_SIZE / qi;
+
+ const int & ncols_dst = ncols_y;
+
+ const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
+ const int & row_x_0 = row_dst_0;
+
+ const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
+ const int & col_y_0 = col_dst_0;
+
+ float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+
+ for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+ load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
+ tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
+ nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
+ blocks_per_row_x);
+
+#pragma unroll
+ for (int ir = 0; ir < qr; ++ir) {
+ const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
+ const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+ for (int i = 0; i < mmq_x; i += nwarps) {
+ const int col_y_eff = dpct::min(
+ (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
+ ncols_y - 1); // to prevent out-of-bounds memory accesses
+
+ const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+ const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
+ kqs % WARP_SIZE;
+ tile_y_qs[index_y] = get_int_from_int8_aligned(
+ by0->qs, item_ct1.get_local_id(2) % QI8_1);
+ }
+
+#pragma unroll
+ for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+ const int ids =
+ (ids0 + item_ct1.get_local_id(1) * QI8_1 +
+ item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
+ mmq_x;
+ const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
+ const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);
+
+ // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+ const sycl::half2 *dsi_src =
+ &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
+ ir * (WARP_SIZE / QI8_1) + kby]
+ .ds;
+ sycl::half2 *dsi_dst =
+ &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
+ if (need_sum) {
+ *dsi_dst = *dsi_src;
+ } else {
+ float * dfi_dst = (float *) dsi_dst;
+ *dfi_dst = (*dsi_src)[0];
+ }
+ }
+
+ /*
+ DPCT1118:9: SYCL group functions and algorithms must be encountered
+ in converged control flow. You may need to adjust the code.
+ */
+ /*
+ DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+ better performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+ for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
+#pragma unroll
+ for (int j = 0; j < mmq_x; j += nwarps) {
+#pragma unroll
+ for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+ sum[i / WARP_SIZE][j / nwarps] += vec_dot(
+ tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+ tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
+ item_ct1.get_local_id(1) + j, k);
+ }
+ }
+ }
+
+ /*
+ DPCT1118:10: SYCL group functions and algorithms must be encountered
+ in converged control flow. You may need to adjust the code.
+ */
+ /*
+ DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+ better performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < mmq_x; j += nwarps) {
+ const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);
+
+ if (col_dst >= ncols_dst) {
+ return;
+ }
+
+#pragma unroll
+ for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+ const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;
+
+ if (row_dst >= nrows_dst) {
+ continue;
+ }
+
+ dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
+ }
+ }
+}
+
+#define MMQ_X_Q4_0_RDNA2 64
+#define MMQ_Y_Q4_0_RDNA2 128
+#define NWARPS_Q4_0_RDNA2 8
+#define MMQ_X_Q4_0_RDNA1 64
+#define MMQ_Y_Q4_0_RDNA1 64
+#define NWARPS_Q4_0_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 4
+#else
+#define MMQ_X_Q4_0_AMPERE 64
+#define MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#endif
+#define MMQ_X_Q4_0_PASCAL 64
+#define MMQ_Y_Q4_0_PASCAL 64
+#define NWARPS_Q4_0_PASCAL 8
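+// The MMQ_X/MMQ_Y/NWARPS constants are per-architecture tile sizes carried
+// over from the CUDA backend (hence the RDNA/AMPERE/PASCAL names); the SYCL
+// wrappers below currently use only the *_AMPERE values, with a smaller tile
+// when SYCL_USE_XMX is defined.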
+
+template <bool need_check> static void
+ mul_mat_q4_0(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+
+ const int mmq_x = MMQ_X_Q4_0_AMPERE;
+ const int mmq_y = MMQ_Y_Q4_0_AMPERE;
+ const int nwarps = NWARPS_Q4_0_AMPERE;
+ allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_qs_q4_0, tile_x_d_q4_0);
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
+ vec_dot_q4_0_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q4_1_RDNA2 64
+#define MMQ_Y_Q4_1_RDNA2 128
+#define NWARPS_Q4_1_RDNA2 8
+#define MMQ_X_Q4_1_RDNA1 64
+#define MMQ_Y_Q4_1_RDNA1 64
+#define NWARPS_Q4_1_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
+#define NWARPS_Q4_1_AMPERE 4
+#else
+#define MMQ_X_Q4_1_AMPERE 64
+#define MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#endif
+#define MMQ_X_Q4_1_PASCAL 64
+#define MMQ_Y_Q4_1_PASCAL 64
+#define NWARPS_Q4_1_PASCAL 8
+
+template <bool need_check> static void
+ mul_mat_q4_1(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
+ sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q4_1_AMPERE;
+ const int mmq_y = MMQ_Y_Q4_1_AMPERE;
+ const int nwarps = NWARPS_Q4_1_AMPERE;
+ allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_qs_q4_1, tile_x_dm_q4_1);
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
+ vec_dot_q4_1_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q5_0_RDNA2 64
+#define MMQ_Y_Q5_0_RDNA2 128
+#define NWARPS_Q5_0_RDNA2 8
+#define MMQ_X_Q5_0_RDNA1 64
+#define MMQ_Y_Q5_0_RDNA1 64
+#define NWARPS_Q5_0_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
+#define NWARPS_Q5_0_AMPERE 4
+#else
+#define MMQ_X_Q5_0_AMPERE 128
+#define MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#endif
+#define MMQ_X_Q5_0_PASCAL 64
+#define MMQ_Y_Q5_0_PASCAL 64
+#define NWARPS_Q5_0_PASCAL 8
+
+template <bool need_check> static void
+ mul_mat_q5_0(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+//sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q5_0_AMPERE;
+ const int mmq_y = MMQ_Y_Q5_0_AMPERE;
+ const int nwarps = NWARPS_Q5_0_AMPERE;
+ allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q5_0, tile_x_d_q5_0);
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
+ vec_dot_q5_0_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q5_1_RDNA2 64
+#define MMQ_Y_Q5_1_RDNA2 128
+#define NWARPS_Q5_1_RDNA2 8
+#define MMQ_X_Q5_1_RDNA1 64
+#define MMQ_Y_Q5_1_RDNA1 64
+#define NWARPS_Q5_1_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
+#define NWARPS_Q5_1_AMPERE 4
+#else
+#define MMQ_X_Q5_1_AMPERE 128
+#define MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#endif
+#define MMQ_X_Q5_1_PASCAL 64
+#define MMQ_Y_Q5_1_PASCAL 64
+#define NWARPS_Q5_1_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q5_1(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
+ sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q5_1_AMPERE;
+ const int mmq_y = MMQ_Y_Q5_1_AMPERE;
+ const int nwarps = NWARPS_Q5_1_AMPERE;
+ allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q5_1, tile_x_dm_q5_1);
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
+ vec_dot_q5_1_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q8_0_RDNA2 64
+#define MMQ_Y_Q8_0_RDNA2 128
+#define NWARPS_Q8_0_RDNA2 8
+#define MMQ_X_Q8_0_RDNA1 64
+#define MMQ_Y_Q8_0_RDNA1 64
+#define NWARPS_Q8_0_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
+#define NWARPS_Q8_0_AMPERE 4
+#else
+#define MMQ_X_Q8_0_AMPERE 128
+#define MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#endif
+#define MMQ_X_Q8_0_PASCAL 64
+#define MMQ_Y_Q8_0_PASCAL 64
+#define NWARPS_Q8_0_PASCAL 8
+
+template <bool need_check> static void
+ mul_mat_q8_0(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q8_0_AMPERE;
+ const int mmq_y = MMQ_Y_Q8_0_AMPERE;
+ const int nwarps = NWARPS_Q8_0_AMPERE;
+ allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_qs_q8_0, tile_x_d_q8_0);
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
+ vec_dot_q8_0_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
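+// The k-quant kernels follow the same scheme, but additionally stage the
+// per-block scales in a tile_x_sc tile (and, for q3_K, the high quant bits
+// in tile_x_qh).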
+#define MMQ_X_Q2_K_RDNA2 64
+#define MMQ_Y_Q2_K_RDNA2 128
+#define NWARPS_Q2_K_RDNA2 8
+#define MMQ_X_Q2_K_RDNA1 128
+#define MMQ_Y_Q2_K_RDNA1 32
+#define NWARPS_Q2_K_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
+#define NWARPS_Q2_K_AMPERE 4
+#else
+#define MMQ_X_Q2_K_AMPERE 64
+#define MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#endif
+#define MMQ_X_Q2_K_PASCAL 64
+#define MMQ_Y_Q2_K_PASCAL 64
+#define NWARPS_Q2_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q2_K(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
+ sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
+ sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q2_K_AMPERE;
+ const int mmq_y = MMQ_Y_Q2_K_AMPERE;
+ const int nwarps = NWARPS_Q2_K_AMPERE;
+ allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
+ vec_dot_q2_K_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q3_K_RDNA2 128
+#define MMQ_Y_Q3_K_RDNA2 64
+#define NWARPS_Q3_K_RDNA2 8
+#define MMQ_X_Q3_K_RDNA1 32
+#define MMQ_Y_Q3_K_RDNA1 128
+#define NWARPS_Q3_K_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
+#define NWARPS_Q3_K_AMPERE 4
+#else
+#define MMQ_X_Q3_K_AMPERE 128
+#define MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#endif
+#define MMQ_X_Q3_K_PASCAL 64
+#define MMQ_Y_Q3_K_PASCAL 64
+#define NWARPS_Q3_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q3_K(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
+ sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
+ int *tile_y_qs, sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q3_K_AMPERE;
+ const int mmq_y = MMQ_Y_Q3_K_AMPERE;
+ const int nwarps = NWARPS_Q3_K_AMPERE;
+ allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
+ tile_x_sc_q3_K);
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
+ vec_dot_q3_K_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q4_K_RDNA2 64
+#define MMQ_Y_Q4_K_RDNA2 128
+#define NWARPS_Q4_K_RDNA2 8
+#define MMQ_X_Q4_K_RDNA1 32
+#define MMQ_Y_Q4_K_RDNA1 64
+#define NWARPS_Q4_K_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
+#define NWARPS_Q4_K_AMPERE 4
+#else
+#define MMQ_X_Q4_K_AMPERE 64
+#define MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#endif
+#define MMQ_X_Q4_K_PASCAL 64
+#define MMQ_Y_Q4_K_PASCAL 64
+#define NWARPS_Q4_K_PASCAL 8
+
+template <bool need_check> static void
+ mul_mat_q4_K(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
+ sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
+ sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q4_K_AMPERE;
+ const int mmq_y = MMQ_Y_Q4_K_AMPERE;
+ const int nwarps = NWARPS_Q4_K_AMPERE;
+ allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
+ vec_dot_q4_K_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q5_K_RDNA2 64
+#define MMQ_Y_Q5_K_RDNA2 128
+#define NWARPS_Q5_K_RDNA2 8
+#define MMQ_X_Q5_K_RDNA1 32
+#define MMQ_Y_Q5_K_RDNA1 64
+#define NWARPS_Q5_K_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
+#define NWARPS_Q5_K_AMPERE 4
+#else
+#define MMQ_X_Q5_K_AMPERE 64
+#define MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#endif
+#define MMQ_X_Q5_K_PASCAL 64
+#define MMQ_Y_Q5_K_PASCAL 64
+#define NWARPS_Q5_K_PASCAL 8
+
+template <bool need_check> static void
+mul_mat_q5_K(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
+ sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
+ sycl::half2 *tile_y_ds) {
+ int * tile_x_ql = nullptr;
+ sycl::half2 *tile_x_dm = nullptr;
+ int * tile_x_qh = nullptr;
+ int * tile_x_sc = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q5_K_AMPERE;
+ const int mmq_y = MMQ_Y_Q5_K_AMPERE;
+ const int nwarps = NWARPS_Q5_K_AMPERE;
+ allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
+ vec_dot_q5_K_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
+#define MMQ_X_Q6_K_RDNA2 64
+#define MMQ_Y_Q6_K_RDNA2 128
+#define NWARPS_Q6_K_RDNA2 8
+#define MMQ_X_Q6_K_RDNA1 32
+#define MMQ_Y_Q6_K_RDNA1 64
+#define NWARPS_Q6_K_RDNA1 8
+#if defined(SYCL_USE_XMX)
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
+#define NWARPS_Q6_K_AMPERE 4
+#else
+#define MMQ_X_Q6_K_AMPERE 64
+#define MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#endif
+#define MMQ_X_Q6_K_PASCAL 64
+#define MMQ_Y_Q6_K_PASCAL 64
+#define NWARPS_Q6_K_PASCAL 8
+
+template <bool need_check> static void
+ mul_mat_q6_K(
+ const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
+ const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
+ int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
+    // q6_K takes its ql/dm/sc tiles directly from the kernel arguments;
+    // there is no separate qh tile, so that pointer stays null
+    int   * tile_x_qh = nullptr;
+
+    // sycl_todo: change according to hardware
+ const int mmq_x = MMQ_X_Q6_K_AMPERE;
+ const int mmq_y = MMQ_Y_Q6_K_AMPERE;
+ const int nwarps = NWARPS_Q6_K_AMPERE;
+ allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
+ tile_x_ql, tile_x_dm, tile_x_sc);
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
+ vec_dot_q6_K_q8_1_mul_mat>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
+ tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
+}
+
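+// Host-side launcher: picks a tile preset from the device generation macros
+// (VER_GEN13 / VER_GEN12 / VER_GEN9 / VER_4VEC), derives the work-group grid
+// (one group per mmq_y x mmq_x output tile) and the (1, nwarps, WARP_SIZE)
+// work-group shape, requires fp16 support, allocates the tiles as local
+// accessors and submits the kernel. The bounds-checking variant
+// (need_check = true) is only used when nrows_x is not a multiple of mmq_y.
+// The launchers for the remaining quant types are identical apart from the
+// tile shapes.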
+static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q4_0_RDNA2;
+ mmq_y = MMQ_Y_Q4_0_RDNA2;
+ nwarps = NWARPS_Q4_0_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q4_0_RDNA1;
+ mmq_y = MMQ_Y_Q4_0_RDNA1;
+ nwarps = NWARPS_Q4_0_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q4_0_AMPERE;
+ mmq_y = MMQ_Y_Q4_0_AMPERE;
+ nwarps = NWARPS_Q4_0_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q4_0_PASCAL;
+ mmq_y = MMQ_Y_Q4_0_PASCAL;
+ nwarps = NWARPS_Q4_0_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q4_0_acc_ct1),
+ get_pointer(tile_x_d_q4_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q4_0_acc_ct1),
+ get_pointer(tile_x_d_q4_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q4_1_RDNA2;
+ mmq_y = MMQ_Y_Q4_1_RDNA2;
+ nwarps = NWARPS_Q4_1_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q4_1_RDNA1;
+ mmq_y = MMQ_Y_Q4_1_RDNA1;
+ nwarps = NWARPS_Q4_1_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q4_1_AMPERE;
+ mmq_y = MMQ_Y_Q4_1_AMPERE;
+ nwarps = NWARPS_Q4_1_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q4_1_PASCAL;
+ mmq_y = MMQ_Y_Q4_1_PASCAL;
+ nwarps = NWARPS_Q4_1_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_1<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q4_1_acc_ct1),
+ get_pointer(tile_x_dm_q4_1_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
+                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_1<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q4_1_acc_ct1),
+ get_pointer(tile_x_dm_q4_1_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q5_0_RDNA2;
+ mmq_y = MMQ_Y_Q5_0_RDNA2;
+ nwarps = NWARPS_Q5_0_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q5_0_RDNA1;
+ mmq_y = MMQ_Y_Q5_0_RDNA1;
+ nwarps = NWARPS_Q5_0_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q5_0_AMPERE;
+ mmq_y = MMQ_Y_Q5_0_AMPERE;
+ nwarps = NWARPS_Q5_0_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q5_0_PASCAL;
+ mmq_y = MMQ_Y_Q5_0_PASCAL;
+ nwarps = NWARPS_Q5_0_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_0_acc_ct1),
+ get_pointer(tile_x_d_q5_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_0_acc_ct1),
+ get_pointer(tile_x_d_q5_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q5_1_RDNA2;
+ mmq_y = MMQ_Y_Q5_1_RDNA2;
+ nwarps = NWARPS_Q5_1_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q5_1_RDNA1;
+ mmq_y = MMQ_Y_Q5_1_RDNA1;
+ nwarps = NWARPS_Q5_1_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q5_1_AMPERE;
+ mmq_y = MMQ_Y_Q5_1_AMPERE;
+ nwarps = NWARPS_Q5_1_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q5_1_PASCAL;
+ mmq_y = MMQ_Y_Q5_1_PASCAL;
+ nwarps = NWARPS_Q5_1_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_1<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_1_acc_ct1),
+ get_pointer(tile_x_dm_q5_1_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_1<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_1_acc_ct1),
+ get_pointer(tile_x_dm_q5_1_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q8_0_RDNA2;
+ mmq_y = MMQ_Y_Q8_0_RDNA2;
+ nwarps = NWARPS_Q8_0_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q8_0_RDNA1;
+ mmq_y = MMQ_Y_Q8_0_RDNA1;
+ nwarps = NWARPS_Q8_0_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q8_0_AMPERE;
+ mmq_y = MMQ_Y_Q8_0_AMPERE;
+ nwarps = NWARPS_Q8_0_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q8_0_PASCAL;
+ mmq_y = MMQ_Y_Q8_0_PASCAL;
+ nwarps = NWARPS_Q8_0_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q8_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q8_0_acc_ct1),
+ get_pointer(tile_x_d_q8_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
+ cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q8_0<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_qs_q8_0_acc_ct1),
+ get_pointer(tile_x_d_q8_0_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q2_K_RDNA2;
+ mmq_y = MMQ_Y_Q2_K_RDNA2;
+ nwarps = NWARPS_Q2_K_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q2_K_RDNA1;
+ mmq_y = MMQ_Y_Q2_K_RDNA1;
+ nwarps = NWARPS_Q2_K_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q2_K_AMPERE;
+ mmq_y = MMQ_Y_Q2_K_AMPERE;
+ nwarps = NWARPS_Q2_K_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q2_K_PASCAL;
+ mmq_y = MMQ_Y_Q2_K_PASCAL;
+ nwarps = NWARPS_Q2_K_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q2_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q2_K_acc_ct1),
+ get_pointer(tile_x_dm_q2_K_acc_ct1),
+ get_pointer(tile_x_sc_q2_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q2_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q2_K_acc_ct1),
+ get_pointer(tile_x_dm_q2_K_acc_ct1),
+ get_pointer(tile_x_sc_q2_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
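+// Note: the q3_K launcher body is compiled only for QK_K == 256; for other
+// super-block sizes it is preprocessed away and the call becomes a no-op.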
+static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+#if QK_K == 256
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q3_K_RDNA2;
+ mmq_y = MMQ_Y_Q3_K_RDNA2;
+ nwarps = NWARPS_Q3_K_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q3_K_RDNA1;
+ mmq_y = MMQ_Y_Q3_K_RDNA1;
+ nwarps = NWARPS_Q3_K_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q3_K_AMPERE;
+ mmq_y = MMQ_Y_Q3_K_AMPERE;
+ nwarps = NWARPS_Q3_K_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q3_K_PASCAL;
+ mmq_y = MMQ_Y_Q3_K_PASCAL;
+ nwarps = NWARPS_Q3_K_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q3_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q3_K_acc_ct1),
+ get_pointer(tile_x_dm_q3_K_acc_ct1),
+ get_pointer(tile_x_qh_q3_K_acc_ct1),
+ get_pointer(tile_x_sc_q3_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q3_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q3_K_acc_ct1),
+ get_pointer(tile_x_dm_q3_K_acc_ct1),
+ get_pointer(tile_x_qh_q3_K_acc_ct1),
+ get_pointer(tile_x_sc_q3_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+#endif
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q4_K_RDNA2;
+ mmq_y = MMQ_Y_Q4_K_RDNA2;
+ nwarps = NWARPS_Q4_K_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q4_K_RDNA1;
+ mmq_y = MMQ_Y_Q4_K_RDNA1;
+ nwarps = NWARPS_Q4_K_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q4_K_AMPERE;
+ mmq_y = MMQ_Y_Q4_K_AMPERE;
+ nwarps = NWARPS_Q4_K_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q4_K_PASCAL;
+ mmq_y = MMQ_Y_Q4_K_PASCAL;
+ nwarps = NWARPS_Q4_K_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q4_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q4_K_acc_ct1),
+ get_pointer(tile_x_dm_q4_K_acc_ct1),
+ get_pointer(tile_x_sc_q4_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q5_K_RDNA2;
+ mmq_y = MMQ_Y_Q5_K_RDNA2;
+ nwarps = NWARPS_Q5_K_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q5_K_RDNA1;
+ mmq_y = MMQ_Y_Q5_K_RDNA1;
+ nwarps = NWARPS_Q5_K_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q5_K_AMPERE;
+ mmq_y = MMQ_Y_Q5_K_AMPERE;
+ nwarps = NWARPS_Q5_K_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q5_K_PASCAL;
+ mmq_y = MMQ_Y_Q5_K_PASCAL;
+ nwarps = NWARPS_Q5_K_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q5_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_q5_K_acc_ct1),
+ get_pointer(tile_x_dm_q5_K_acc_ct1),
+ get_pointer(tile_x_sc_q5_K_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
+static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols_x,
+ const int nrows_x, const int ncols_y,
+ const int nrows_y, const int nrows_dst,
+ dpct::queue_ptr stream) try {
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const int compute_capability = ggml_sycl_info().devices[id].cc;
+
+ int mmq_x, mmq_y, nwarps;
+ if (compute_capability >= VER_GEN13) {
+ mmq_x = MMQ_X_Q6_K_RDNA2;
+ mmq_y = MMQ_Y_Q6_K_RDNA2;
+ nwarps = NWARPS_Q6_K_RDNA2;
+ } else if (compute_capability >= VER_GEN12) {
+ mmq_x = MMQ_X_Q6_K_RDNA1;
+ mmq_y = MMQ_Y_Q6_K_RDNA1;
+ nwarps = NWARPS_Q6_K_RDNA1;
+ } else if (compute_capability >= VER_GEN9) {
+ mmq_x = MMQ_X_Q6_K_AMPERE;
+ mmq_y = MMQ_Y_Q6_K_AMPERE;
+ nwarps = NWARPS_Q6_K_AMPERE;
+ } else if (compute_capability >= VER_4VEC) {
+ mmq_x = MMQ_X_Q6_K_PASCAL;
+ mmq_y = MMQ_Y_Q6_K_PASCAL;
+ nwarps = NWARPS_Q6_K_PASCAL;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const sycl::range<3> block_nums(1, block_num_y, block_num_x);
+ const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ /*
+ DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q6_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_acc_ct1),
+ get_pointer(tile_x_dm_acc_ct1),
+ get_pointer(tile_x_sc_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ } else {
+ const bool need_check = true;
+ /*
+ DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
+ sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
+ cgh);
+ sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
+ sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
+ sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE), cgh);
+ sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
+ sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ mul_mat_q6_K<need_check>(
+ vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
+ nrows_dst, item_ct1,
+ get_pointer(tile_x_ql_acc_ct1),
+ get_pointer(tile_x_dm_acc_ct1),
+ get_pointer(tile_x_sc_acc_ct1),
+ get_pointer(tile_y_qs_acc_ct1),
+ get_pointer(tile_y_ds_acc_ct1));
+ });
+ });
+ }
+ }
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
+
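+// Backend entry point: asserts that src1 is padded to whole q8_1 blocks,
+// determines how many destination rows this device writes (the main device
+// keeps room for the full result), and dispatches to the launcher matching
+// the quantization type of src0.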
+void ggml_sycl_op_mul_mat_q(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const dpct::queue_ptr &stream) try {
+
+ const int64_t ne00 = src0->ne[0];
+
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+
+ int device_id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(device_id = get_current_device_id()));
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
+ const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ (void) src1;
+ (void) dst;
+ (void) src1_ddf_i;
+}
+catch (sycl::exception const &exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+}
diff --git a/ggml/src/ggml-sycl/mmq.hpp b/ggml/src/ggml-sycl/mmq.hpp
new file mode 100644
index 00000000..3f5297aa
--- /dev/null
+++ b/ggml/src/ggml-sycl/mmq.hpp
@@ -0,0 +1,33 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMQ_HPP
+#define GGML_SYCL_MMQ_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_mul_mat_q(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor* src0,
+ const ggml_tensor* src1,
+ ggml_tensor* dst,
+ const char* src0_dd_i,
+ const float* src1_ddf_i,
+ const char* src1_ddq_i,
+ float* dst_dd_i,
+ const int64_t row_low,
+ const int64_t row_high,
+ const int64_t src1_ncols,
+ const int64_t src1_padded_row_size,
+ const dpct::queue_ptr& stream);
+
+#endif // GGML_SYCL_MMQ_HPP
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
new file mode 100644
index 00000000..3fbc4dd6
--- /dev/null
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -0,0 +1,1027 @@
+#include "mmvq.hpp"
+#include "vecdotq.hpp"
+
+
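+// Matrix-vector kernel: each sub-group handles one destination row. Every
+// work-item accumulates vec_dot products over a strided subset of the
+// quantized blocks in its row; the partial sums are then combined with a
+// sub-group XOR shuffle reduction and lane 0 writes the result.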
+template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
+static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
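+// The iq2/iq3 variants below are near-copies of mul_mat_vec_q that differ
+// only in the vec_dot call; several pass additional lookup tables
+// (e.g. iq2xxs_grid, ksigns_iq2xs) through to it.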
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+    // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid, ksigns64);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq2_s_q8_1(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid, ksigns64);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq1_m_q8_1(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq4_nl_q8_1(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
+ const void *__restrict__ vy,
+ float *__restrict__ dst, const int ncols,
+ const int nrows,
+ const sycl::nd_item<3> &item_ct1) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+ i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i; // x block index
+
+ const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+ const int iqs =
+ vdr *
+ (item_ct1.get_local_id(2) %
+ (qi / vdr)); // x block quant index when casting the quants to int
+
+ tmp += vec_dot_iq4_xs_q8_1(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ tmp +=
+ dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+ }
+
+ if (item_ct1.get_local_id(2) == 0) {
+ dst[row] = tmp;
+ }
+}
+
+static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK4_0 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
+ VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK4_1 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q<QK4_1, QI4_1, block_q4_1,
+ VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK5_0 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
+ VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK5_1 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
+ VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK8_0 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
+ VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
+ VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
+ VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
+ VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
+ VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
+ VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+
+static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0];
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0];
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ auto iq3s_grid_ptr_ct1 = &iq3s_grid[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0];
+ auto ksigns64_ptr_ct1 = &ksigns64[0];
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK4_NL == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
+ float *dst, const int ncols,
+ const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_K == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+ {
+
+ stream->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
+
+void ggml_sycl_op_mul_mat_vec_q(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_col_size,
+ const dpct::queue_ptr &stream) {
+
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ int id;
+ SYCL_CHECK(
+ CHECK_TRY_ERROR(id = get_current_device_id()));
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff;
+ for (int i = 0; i < src1_ncols; i++)
+ {
+ const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
+ const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
+ float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ1_M:
+ mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+ }
+ (void) src1;
+ (void) dst;
+ (void) src1_ddf_i;
+}
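
Each kernel in this file accumulates a per-thread partial dot product over the quant blocks of one row and then folds the sub-group's partial sums together with XOR exchanges (dpct::permute_sub_group_by_xor), the SYCL counterpart of a CUDA __shfl_xor butterfly reduction. The host-side sketch below, assuming a sub-group width of 32, emulates that butterfly on a plain array just to show why log2(width) exchange steps leave every lane holding the full row sum; it is not part of the backend.

    // Host-side emulation of the XOR (butterfly) reduction used by the
    // mul_mat_vec_q kernels above. Each "lane" starts with a partial sum;
    // after log2(width) exchange steps every lane holds the total.
    #include <cstdio>
    #include <vector>

    int main() {
        constexpr int width = 32;                 // assumed sub-group size
        std::vector<float> lane(width);
        for (int i = 0; i < width; ++i) {
            lane[i] = 1.0f + i;                   // stand-in partial sums
        }
        for (int mask = width / 2; mask > 0; mask >>= 1) {
            std::vector<float> next(lane);
            for (int i = 0; i < width; ++i) {
                next[i] = lane[i] + lane[i ^ mask];  // exchange with lane i^mask
            }
            lane = next;
        }
        std::printf("every lane now holds %g\n", lane[0]);  // 1+2+...+32 = 528
        return 0;
    }
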
diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp
new file mode 100644
index 00000000..049b43d4
--- /dev/null
+++ b/ggml/src/ggml-sycl/mmvq.hpp
@@ -0,0 +1,27 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_MMVQ_HPP
+#define GGML_SYCL_MMVQ_HPP
+
+#include "common.hpp"
+
+
+void ggml_sycl_op_mul_mat_vec_q(
+ ggml_backend_sycl_context & ctx,
+ const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+ float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+ const int64_t src1_ncols, const int64_t src1_padded_row_size,
+ const dpct::queue_ptr &stream);
+
+#endif // GGML_SYCL_MMVQ_HPP
diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp
new file mode 100644
index 00000000..cccf87d0
--- /dev/null
+++ b/ggml/src/ggml-sycl/norm.cpp
@@ -0,0 +1,374 @@
+#include "norm.hpp"
+
+static void norm_f32(const float* x, float* dst, const int ncols, const float eps,
+ const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ const int tid = item_ct1.get_local_id(2);
+
+ const int nthreads = item_ct1.get_local_range(2);
+ const int nwarps = nthreads / WARP_SIZE;
+ assert(nwarps % WARP_SIZE == 0);
+ sycl::float2 mean_var = sycl::float2(0.f, 0.f);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row * ncols + col];
+ mean_var.x() += xi;
+ mean_var.y() += xi * xi;
+ }
+
+ // sum up partial sums
+ mean_var = warp_reduce_sum(mean_var, item_ct1);
+ if (block_size > WARP_SIZE) {
+
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = mean_var;
+ }
+ /*
+ DPCT1118:0: SYCL group functions and algorithms must be encountered in
+ converged control flow. You may need to adjust the code.
+ */
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ mean_var = 0.f;
+ int nreduce = nwarps / WARP_SIZE;
+ for (size_t i = 0; i < nreduce; i += 1)
+ {
+ mean_var += s_sum[lane_id + i * WARP_SIZE];
+ }
+ mean_var = warp_reduce_sum(mean_var, item_ct1);
+ }
+
+ const float mean = mean_var.x() / ncols;
+ const float var = mean_var.y() / ncols - mean * mean;
+ const float inv_std = sycl::rsqrt(var + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std;
+ }
+}
+
+static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
+ const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+ int start = item_ct1.get_group(2) * group_size;
+ int end = start + group_size;
+ const int nthreads = item_ct1.get_local_range(2);
+ const int nwarps = nthreads / WARP_SIZE;
+ assert(nwarps % WARP_SIZE == 0);
+ start += item_ct1.get_local_id(2);
+ int nreduce = nwarps / WARP_SIZE;
+
+ if (end >= ne_elements) {
+ end = ne_elements;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += block_size) {
+ tmp += x[j];
+ }
+
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ if (block_size > WARP_SIZE) {
+
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ /*
+ DPCT1118:1: SYCL group functions and algorithms must be encountered in
+ converged control flow. You may need to adjust the code.
+ */
+ /*
+ DPCT1065:54: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+ better performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+ tmp = 0.f;
+ for (size_t i = 0; i < nreduce; i += 1)
+ {
+ tmp += s_sum[lane_id + i * WARP_SIZE];
+ }
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ }
+
+ float mean = tmp / group_size;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += block_size) {
+ float xi = x[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ if (block_size > WARP_SIZE) {
+
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ /*
+ DPCT1118:2: SYCL group functions and algorithms must be encountered in
+ converged control flow. You may need to adjust the code.
+ */
+ /*
+ DPCT1065:55: Consider replacing sycl::nd_item::barrier() with
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
+ better performance if there is no access to global memory.
+ */
+ item_ct1.barrier();
+ tmp = 0.f;
+ for (size_t i = 0; i < nreduce; i += 1)
+ {
+ tmp += s_sum[lane_id + i * WARP_SIZE];
+ }
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ }
+
+ float variance = tmp / group_size;
+ float scale = sycl::rsqrt(variance + eps);
+ for (int j = start; j < end; j += block_size) {
+ dst[j] *= scale;
+ }
+}
+
+static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps,
+ const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
+ const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+ item_ct1.get_local_id(1);
+ const int tid = item_ct1.get_local_id(2);
+ const int nthreads = item_ct1.get_local_range(2);
+ const int nwarps = nthreads / WARP_SIZE;
+ assert(nwarps % WARP_SIZE == 0);
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row * ncols + col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ if (block_size > WARP_SIZE) {
+
+ int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ /*
+ DPCT1118:3: SYCL group functions and algorithms must be encountered in
+ converged control flow. You may need to adjust the code.
+ */
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ int nreduce = nwarps / WARP_SIZE;
+ tmp = 0.f;
+ for (size_t i = 0; i < nreduce; i += 1)
+ {
+ tmp += s_sum[lane_id + i * WARP_SIZE];
+ }
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ }
+
+ const float mean = tmp / ncols;
+ const float scale = sycl::rsqrt(mean + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row * ncols + col] = scale * x[row * ncols + col];
+ }
+}
+
+static void norm_f32_sycl(const float* x, float* dst, const int ncols,
+ const int nrows, const float eps,
+ queue_ptr stream, int device) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ stream->submit([&](sycl::handler& cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ norm_f32(x, dst, ncols, eps, item_ct1,
+ nullptr, WARP_SIZE);
+ });
+ });
+ }
+ else {
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ const sycl::range<3> block_dims(1, 1, work_group_size);
+ /*
+ DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
+ sycl::range<1>(work_group_size / WARP_SIZE), cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ norm_f32(x, dst, ncols, eps, item_ct1,
+ get_pointer(s_sum_acc_ct1), work_group_size);
+ });
+ });
+ }
+}
+
+static void group_norm_f32_sycl(const float* x, float* dst,
+ const int num_groups, const int group_size,
+ const int ne_elements, queue_ptr stream, int device) {
+ static const float eps = 1e-6f;
+ if (group_size < 1024) {
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ stream->submit([&](sycl::handler& cgh) {
+ const float eps_ct4 = eps;
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ group_norm_f32(
+ x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+ nullptr, WARP_SIZE);
+ });
+ });
+ }
+ else {
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ const sycl::range<3> block_dims(1, 1, work_group_size);
+ /*
+ DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+ cgh);
+
+ const float eps_ct4 = eps;
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ group_norm_f32(x, dst, group_size, ne_elements,
+ eps_ct4, item_ct1,
+ get_pointer(s_sum_acc_ct1), work_group_size);
+ });
+ });
+ }
+}
+
+static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
+ const int nrows, const float eps,
+ queue_ptr stream, int device) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+ if (ncols < 1024) {
+ const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+ stream->submit([&](sycl::handler& cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ rms_norm_f32(x, dst, ncols, eps, item_ct1,
+ nullptr, WARP_SIZE);
+ });
+ });
+ }
+ else {
+ const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+ const sycl::range<3> block_dims(1, 1, work_group_size);
+ /*
+ DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ stream->submit([&](sycl::handler& cgh) {
+ sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+ cgh);
+ cgh.parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+ block_dims),
+ [=](sycl::nd_item<3> item_ct1)
+ [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ rms_norm_f32(x, dst, ncols, eps, item_ct1,
+ get_pointer(s_sum_acc_ct1), work_group_size);
+ });
+ });
+ }
+}
+
+void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
+ ggml_tensor* dst, const float* src0_dd,
+ const float* src1_dd, float* dst_dd,
+ const queue_ptr& main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+ (void)src1;
+ (void)dst;
+ (void)src1_dd;
+}
+
+void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ int num_groups = dst->op_params[0];
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+ group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+
+ (void)src1;
+ (void)dst;
+ (void)src1_dd;
+}
+
+void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+
+ (void)src1;
+ (void)dst;
+ (void)src1_dd;
+}
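
The three kernels above differ only in which statistics they remove: norm_f32 subtracts the row mean and divides by the standard deviation, group_norm_f32 does the same over a group of elements, and rms_norm_f32 skips the mean and only rescales by the root mean square. As a cross-check, here is a minimal host-side reference for one row of the RMS case, assuming the same definition the kernel uses (y = x / sqrt(mean(x^2) + eps)); rms_norm_row_ref is a hypothetical helper, not part of the backend.

    // Host-side reference for one row of RMS norm, matching rms_norm_f32:
    // scale = 1 / sqrt(mean(x^2) + eps), applied element-wise.
    #include <cmath>
    #include <cstddef>

    static void rms_norm_row_ref(const float * x, float * dst, size_t ncols, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < ncols; ++i) {
            sum_sq += x[i] * x[i];                // sum of squares over the row
        }
        const float mean  = sum_sq / (float) ncols;
        const float scale = 1.0f / std::sqrt(mean + eps);
        for (size_t i = 0; i < ncols; ++i) {
            dst[i] = scale * x[i];                // same scaling the kernel applies
        }
    }
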
diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp
new file mode 100644
index 00000000..a9ad9156
--- /dev/null
+++ b/ggml/src/ggml-sycl/norm.hpp
@@ -0,0 +1,35 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_NORM_HPP
+#define GGML_SYCL_NORM_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
+ ggml_tensor* dst, const float* src0_dd,
+ const float* src1_dd, float* dst_dd,
+ const queue_ptr& main_stream);
+
+void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream);
+
+void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+ const ggml_tensor* src1, ggml_tensor* dst,
+ const float* src0_dd, const float* src1_dd,
+ float* dst_dd,
+ const queue_ptr& main_stream);
+
+#endif // GGML_SYCL_NORM_HPP
diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp
new file mode 100644
index 00000000..15ddcac1
--- /dev/null
+++ b/ggml/src/ggml-sycl/presets.hpp
@@ -0,0 +1,66 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_PRESETS_HPP
+#define GGML_SYCL_PRESETS_HPP
+
+#define GGML_SYCL_MAX_STREAMS 8
+#define GGML_SYCL_MAX_BUFFERS 256
+
+#define WARP_SIZE GGML_SYCL_WARP_SIZE
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+
+#define SYCL_GELU_BLOCK_SIZE 256
+#define SYCL_SILU_BLOCK_SIZE 256
+#define SYCL_TANH_BLOCK_SIZE 256
+#define SYCL_RELU_BLOCK_SIZE 256
+#define SYCL_HARDSIGMOID_BLOCK_SIZE 256
+#define SYCL_HARDSWISH_BLOCK_SIZE 256
+#define SYCL_SQR_BLOCK_SIZE 256
+#define SYCL_CPY_BLOCK_SIZE 32
+#define SYCL_SCALE_BLOCK_SIZE 256
+#define SYCL_CLAMP_BLOCK_SIZE 256
+#define SYCL_ROPE_BLOCK_SIZE 256
+#define SYCL_ALIBI_BLOCK_SIZE 32
+#define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
+#define SYCL_QUANTIZE_BLOCK_SIZE 256
+#define SYCL_DEQUANTIZE_BLOCK_SIZE 256
+#define SYCL_GET_ROWS_BLOCK_SIZE 256
+#define SYCL_UPSCALE_BLOCK_SIZE 256
+#define SYCL_CONCAT_BLOCK_SIZE 256
+#define SYCL_PAD_BLOCK_SIZE 256
+#define SYCL_ACC_BLOCK_SIZE 256
+#define SYCL_IM2COL_BLOCK_SIZE 256
+#define SYCL_POOL2D_BLOCK_SIZE 256
+
+// dmmv = dequantize_mul_mat_vec
+#ifndef GGML_SYCL_DMMV_X
+#define GGML_SYCL_DMMV_X 32
+#endif
+#ifndef GGML_SYCL_MMV_Y
+#define GGML_SYCL_MMV_Y 1
+#endif
+
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 2
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+#ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE
+#define GGML_SYCL_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_SYCL_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+#define QK_WARP_SIZE 32
+#endif // GGML_SYCL_PRESETS_HPP
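
These constants set the launch geometry used throughout the backend; for example, the mul_mat_vec_*_sycl launchers in mmvq.cpp turn a row count into a work-group count with a rounding-up division over GGML_SYCL_MMV_Y. A tiny standalone sketch of that arithmetic, with hypothetical values chosen only to make the rounding visible:

    // Sketch of the ceil-divide the launchers use to size the grid.
    // The values below are hypothetical; GGML_SYCL_MMV_Y defaults to 1 above.
    #include <cstdio>

    int main() {
        const int mmv_y = 2;                      // assumed rows per work-group
        const int nrows = 4097;                   // hypothetical row count
        const int block_num_y = (nrows + mmv_y - 1) / mmv_y;   // rounds up
        std::printf("%d work-groups along y\n", block_num_y);  // prints 2049
        return 0;
    }
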
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp
new file mode 100644
index 00000000..6f507941
--- /dev/null
+++ b/ggml/src/ggml-sycl/rope.cpp
@@ -0,0 +1,275 @@
+#include "rope.hpp"
+
+struct rope_corr_dims {
+ float v[2];
+};
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
+ return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+ float * cos_theta, float * sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
+ }
+ *cos_theta = sycl::cos(theta) * mscale;
+ *sin_theta = sycl::sin(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static void rope_norm(
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1));
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + 1];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static void rope_neox(
+ const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1));
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ if (i0 >= n_dims) {
+ const int i = row*ne0 + i0;
+
+ dst[i + 0] = x[i + 0];
+ dst[i + 1] = x[i + 1];
+
+ return;
+ }
+
+ const int i = row*ne0 + i0/2;
+ const int i2 = row/p_delta_rows;
+
+ const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+ const float x0 = x[i + 0];
+ const float x1 = x[i + n_dims/2];
+
+ dst[i + 0] = x0*cos_theta - x1*sin_theta;
+ dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template <typename T>
+static void rope_norm_sycl(
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ if (freq_factors == nullptr) {
+ /*
+ DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+ item_ct1);
+ });
+ } else {
+ /*
+ DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
+ the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if needed.
+ */
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
+ ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
+ item_ct1);
+ });
+ }
+}
+
+template <typename T>
+static void rope_neox_sycl(
+ const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+ GGML_ASSERT(ne0 % 2 == 0);
+ const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+ const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+ const sycl::range<3> block_nums(1, num_blocks_x, nr);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ if (freq_factors == nullptr) {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, freq_factors,
+ item_ct1);
+ });
+ } else {
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
+ p_delta_rows, ext_factor, attn_factor,
+ corr_dims, theta_scale, freq_factors,
+ item_ct1);
+ });
+ }
+}
+
+void ggml_sycl_op_rope(
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t nr = ggml_nrows(src0);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ // RoPE alteration for extended context
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ const bool is_neox = mode & 2;
+
+ const int32_t * pos = (const int32_t *) src1_dd;
+
+ const float * freq_factors = nullptr;
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+
+ rope_corr_dims corr_dims;
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+ // compute
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_neox_sycl(
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, main_stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_neox_sycl(
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, main_stream
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+ } else {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_norm_sycl(
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, main_stream
+ );
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_norm_sycl(
+ (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+ attn_factor, corr_dims, freq_factors, main_stream
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
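
rope_yarn above blends the interpolated angle (freq_scale * theta) with the extrapolated one through the ramp, and each kernel then applies a plain 2D rotation to a pair of row elements; rope_norm pairs adjacent elements (i0, i0+1) while rope_neox pairs elements n_dims/2 apart. A minimal sketch of that final rotation, assuming cos_theta and sin_theta were already produced by rope_yarn; rotate_pair is a hypothetical helper, not part of the backend.

    // The rotation both rope_norm and rope_neox apply once the angle is known.
    #include <utility>

    static std::pair<float, float> rotate_pair(float x0, float x1,
                                               float cos_theta, float sin_theta) {
        // standard 2D rotation; the two layouts differ only in which pair of
        // elements (adjacent vs. n_dims/2 apart) x0 and x1 are taken from
        return { x0 * cos_theta - x1 * sin_theta,
                 x0 * sin_theta + x1 * cos_theta };
    }
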
diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp
new file mode 100644
index 00000000..00354c31
--- /dev/null
+++ b/ggml/src/ggml-sycl/rope.hpp
@@ -0,0 +1,22 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_ROPE_HPP
+#define GGML_SYCL_ROPE_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rope(
+ ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_ROPE_HPP
diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp
new file mode 100644
index 00000000..17a542e4
--- /dev/null
+++ b/ggml/src/ggml-sycl/softmax.cpp
@@ -0,0 +1,251 @@
+#include "norm.hpp"
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+ const int nrows_y, const float scale, const float max_bias, const float m0,
+ const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+ const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+ const int tid = item_ct1.get_local_id(2);
+ const int rowx = item_ct1.get_group(2);
+ const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+ const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+
+ const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+ const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+ const int nthreads = block_size;
+ const int nwarps = nthreads / WARP_SIZE;
+ int nreduce = nwarps / WARP_SIZE;
+ float slope = 1.0f;
+
+ // ALiBi
+ if (max_bias > 0.0f) {
+ const uint32_t h = rowx/nrows_y; // head index
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = sycl::pow(base, float(exp));
+ }
+
+ float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+ float max_val = -INFINITY;
+
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const int ix = rowx*ncols + col;
+ const int iy = rowy*ncols + col;
+
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+
+ vals[col] = val;
+ max_val = sycl::max(max_val, val);
+ }
+
+ // find the max value in the block
+ max_val = warp_reduce_max(max_val, item_ct1);
+ if (block_size > WARP_SIZE) {
+ if (warp_id == 0) {
+ buf[lane_id] = -INFINITY;
+ for (size_t i = 1; i < nreduce; i += 1)
+ buf[lane_id + i * WARP_SIZE] = -INFINITY;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ if (lane_id == 0) {
+ buf[warp_id] = max_val;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ max_val = buf[lane_id];
+ for (size_t i = 1; i < nreduce; i += 1)
+ {
+ max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+ }
+ max_val = warp_reduce_max(max_val, item_ct1);
+ }
+
+ float tmp = 0.f;
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const float val = sycl::native::exp(vals[col] - max_val);
+ tmp += val;
+ vals[col] = val;
+ }
+
+ // find the sum of exps in the block
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ if (block_size > WARP_SIZE) {
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ if (warp_id == 0) {
+ buf[lane_id] = 0.f;
+ for (size_t i = 1; i < nreduce; i += 1)
+ buf[lane_id + i * WARP_SIZE] = 0.f;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ if (lane_id == 0) {
+ buf[warp_id] = tmp;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ tmp = buf[lane_id];
+ for (size_t i = 1; i < nreduce; i += 1)
+ {
+ tmp += buf[lane_id + i * WARP_SIZE];
+ }
+ tmp = warp_reduce_sum(tmp, item_ct1);
+ }
+
+ const float inv_sum = 1.f / tmp;
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ return;
+ }
+
+ const int idst = rowx*ncols + col;
+ dst[idst] = vals[col] * inv_sum;
+ }
+}
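Stated as an equation, the kernel above computes, per row r of x (with the mask row broadcast as r mod nrows_y), the numerically stable softmax of the scaled and bias-shifted logits:

\[
z_{r,c} = \texttt{scale}\cdot x_{r,c} + \alpha_r\, m_{r \bmod n_y,\,c}, \qquad
\texttt{dst}_{r,c} = \frac{\exp\!\big(z_{r,c} - \max_{c'} z_{r,c'}\big)}{\sum_{c'} \exp\!\big(z_{r,c'} - \max_{c''} z_{r,c''}\big)}
\]

The ALiBi slope is \(\alpha_r = 1\) when max_bias == 0; otherwise, for head \(h\), \(\alpha_r = m_0^{\,h+1}\) with \(m_0 = 2^{-\texttt{max\_bias}/n_{\text{head\_log2}}}\) when \(h < n_{\text{head\_log2}}\), and \(\alpha_r = m_1^{\,2(h - n_{\text{head\_log2}})+1}\) with \(m_1 = 2^{-\texttt{max\_bias}/(2 n_{\text{head\_log2}})}\) for the remaining heads (m0 and m1 are computed on the host in soft_max_f32_sycl below).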
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+ const int nrows_y, const float scale, const float max_bias, const float m0,
+ const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+ const size_t n_local_scratch, queue_ptr stream) {
+ stream->submit([&](sycl::handler &cgh) {
+ sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+
+ cgh.parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+ nrows_y, scale, max_bias, m0,
+ m1, n_head_log2, item_ct1,
+ get_pointer(local_buf_acc));
+ });
+ });
+}
+
+static void soft_max_f32_sycl(const float * x, const float * mask,
+ float * dst, const int ncols_x, const int nrows_x,
+ const int nrows_y, const float scale, const float max_bias,
+ queue_ptr stream, int device) {
+ int nth = WARP_SIZE;
+ int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+ while (nth < ncols_x && nth < max_block_size) nth *= 2;
+ if (nth>max_block_size) nth = max_block_size;
+
+ const sycl::range<3> block_dims(1, 1, nth);
+ const sycl::range<3> block_nums(1, 1, nrows_x);
+ const size_t n_val_tmp = nth / WARP_SIZE;
+ const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
+
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+ if (n_local_scratch*sizeof(float) < local_mem_size) {
+ if (ncols_x > max_block_size) {
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ return;
+ }
+ switch (ncols_x) {
+ case 32:
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 64:
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 128:
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 256:
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 512:
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 1024:
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 2048:
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ case 4096:
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ default:
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, n_local_scratch, stream);
+ break;
+ }
+ } else {
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+ max_bias, m0, m1, n_head_log2, block_nums,
+ block_dims, WARP_SIZE, stream);
+ }
+}
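As a quick sizing check (assuming WARP_SIZE is 32 and a typical 64 KiB of SYCL local memory), the shared-scratch variant is selected for all the hard-coded column counts above; for example, with ncols_x = 4096 and nth = 1024:

\[
n_{\text{local\_scratch}} = \mathrm{GGML\_PAD}(4096, 32) + 1024/32 = 4128 \text{ floats} = 16\,512\ \text{bytes} \;<\; 65\,536\ \text{bytes},
\]

so only much wider rows, or devices with very little local memory, fall through to the <false, 0, 0> path, which stages the intermediate values in dst instead.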
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains the optional mask
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
+ nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+}
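When debugging the kernel it can help to compare against a plain host-side re-implementation of the same math. The sketch below is a hypothetical standalone helper (soft_max_ref is not part of this patch) that mirrors the scale/mask/ALiBi handling of soft_max_f32_sycl:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for one soft_max_f32_sycl call:
// x is nrows_x*ncols, mask (optional) is nrows_y*ncols, dst is nrows_x*ncols.
static void soft_max_ref(const float * x, const float * mask, float * dst,
                         int ncols, int nrows_x, int nrows_y,
                         float scale, float max_bias) {
    const uint32_t n_head      = nrows_x / nrows_y;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (int r = 0; r < nrows_x; ++r) {
        float slope = 1.0f;
        if (max_bias > 0.0f) {
            const uint32_t h = r / nrows_y; // head index, as in the kernel
            const float base = h < n_head_log2 ? m0 : m1;
            const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
            slope = powf(base, (float) exp);
        }
        std::vector<float> v(ncols);
        float vmax = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            v[c] = x[r*ncols + c]*scale + (mask ? slope*mask[(r % nrows_y)*ncols + c] : 0.0f);
            vmax = std::max(vmax, v[c]);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) { v[c] = expf(v[c] - vmax); sum += v[c]; }
        for (int c = 0; c < ncols; ++c) { dst[r*ncols + c] = v[c] / sum; }
    }
}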
diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp
new file mode 100644
index 00000000..bdb8f712
--- /dev/null
+++ b/ggml/src/ggml-sycl/softmax.hpp
@@ -0,0 +1,24 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_SOFTMAX_HPP
+#define GGML_SYCL_SOFTMAX_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_SOFTMAX_HPP
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
new file mode 100644
index 00000000..d2dccade
--- /dev/null
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -0,0 +1,1140 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_VECDOTQ_HPP
+#define GGML_SYCL_VECDOTQ_HPP
+
+#include "dpct/helper.hpp"
+
+typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+
+static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
+    // assume at least 2 byte alignment
+    const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32);
+
+    int x32 = 0;
+    x32 |= x16[0] << 0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __dpct_inline__ int get_int_from_uint8(const uint8_t* x8, const int& i32) {
+    // assume at least 2 byte alignment
+    const uint16_t* x16 = (const uint16_t*)(x8 + sizeof(int) * i32);
+
+    int x32 = 0;
+    x32 |= x16[0] << 0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __dpct_inline__ int get_int_from_int8_aligned(const int8_t* x8, const int& i32) {
+    // assume at least 4 byte alignment
+    return *((const int*)(x8 + sizeof(int) * i32));
+}
+
+static __dpct_inline__ int get_int_from_uint8_aligned(const uint8_t* x8, const int& i32) {
+    // assume at least 4 byte alignment
+    return *((const int*)(x8 + sizeof(int) * i32));
+}
+
+static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4,
+ const uint8_t *values,
+ int &val1, int &val2) {
+
+ uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
+ aux32 = q4 & 0x0f0f0f0f;
+ uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
+ uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val1 = v1 | (v2 << 16);
+ aux32 = (q4 >> 4) & 0x0f0f0f0f;
+ v1 = values[q8[0]] | (values[q8[1]] << 8);
+ v2 = values[q8[2]] | (values[q8[3]] << 8);
+ val2 = v1 | (v2 << 16);
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+ const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales,
+ const sycl::half2 &dm2, const float *__restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++i) {
+ const int sc = scales[2*i];
+
+ const int vi = (v >> (2*i)) & 0x03030303;
+
+ sumf_d +=
+ d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+ sumf_m += d8[i] *
+ dpct::dp4a(
+ m, u[i],
+ 0); // multiply constant q2_K part with sum of q8_1 values
+ }
+
+ const sycl::float2 dm2f =
+ dm2.convert<float, sycl::rounding_mode::automatic>();
+
+ return dm2f.x() * sumf_d - dm2f.y() * sumf_m;
+}
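The two accumulators follow the usual k-quant decomposition: a 2-bit value in sub-block i is reconstructed as d·sc_i·q − d_min·m_i, dm2 packs (d, d_min), and the dp4a against the byte-replicated m computes m_i·Σu per 4-value group, so the return value restates

\[
\sum_i d_{8,i}\Big( d\,\mathrm{sc}_i \sum_j q_{ij} u_{ij} \;-\; d_{\min}\, m_i \sum_j u_{ij} \Big)
\;=\; d\cdot\texttt{sumf\_d} \;-\; d_{\min}\cdot\texttt{sumf\_m}.
\]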
+
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+ const int &vl, const int &vh, const int *__restrict__ u,
+ const uint8_t *__restrict__ scales, const int &scale_offset,
+ const float &d3, const float *__restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ const int isc = scale_offset + 2*i;
+
+ const int isc_low = isc % (QK_K/32);
+ const int sc_shift_low = 4 * (isc / (QK_K/32));
+ const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+ const int isc_high = isc % (QK_K/64);
+ const int sc_shift_high = 2 * (isc / (QK_K/64));
+ const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+ const int sc = (sc_low | sc_high) - 32;
+
+ const int vil = (vl >> (2*i)) & 0x03030303;
+
+ const int vih = ((vh >> i) << 2) & 0x04040404;
+
+ const int vi =
+ dpct::vectorized_binary<sycl::char4>(vil, vih, dpct::sub_sat());
+
+ sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d3 * sumf;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+ const int *__restrict__ v, const int *__restrict__ u,
+ const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
+ const sycl::half2 &dm4, const float *__restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K; ++i) {
+ const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+ const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int dot1 =
+ dpct::dp4a(v1i, u[2 * i + 1],
+ dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product
+ const int dot2 =
+ dpct::dp4a(0x01010101, u[2 * i + 1],
+ dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ }
+
+ const sycl::float2 dm4f =
+ dm4.convert<float, sycl::rounding_mode::automatic>();
+
+ return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
+}
+
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+
+// contiguous v/x values
+static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+ const int *__restrict__ vl, const int *__restrict__ vh,
+ const int *__restrict__ u, const uint8_t *__restrict__ sc,
+ const uint8_t *__restrict__ m, const sycl::half2 &dm5,
+ const float *__restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+ const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+ const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+ const int v0i = vl0i | vh0i;
+ const int v1i = vl1i | vh1i;
+
+ const int dot1 =
+ dpct::dp4a(v0i, u[2 * i + 0],
+ dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product
+ const int dot2 =
+ dpct::dp4a(0x01010101, u[2 * i + 0],
+ dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]);
+
+ }
+
+ const sycl::float2 dm5f =
+ dm5.convert<float, sycl::rounding_mode::automatic>();
+
+ return dm5f.x() * sumf_d - dm5f.y() * sumf_m;
+}
+
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+
+// contiguous v/x values
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
+ const int *__restrict__ u,
+ const int8_t *__restrict__ scales, const float &d,
+ const float *__restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ const int sc = scales[4*i];
+
+ const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+ const int vi = dpct::vectorized_binary<sycl::char4>(
+ (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32
+
+ sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d*sumf;
+}
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ 4
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
+ const float &d4,
+ const sycl::half2 &ds8) {
+ int sumi = 0;
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+ sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+ }
+
+ const sycl::float2 ds8f =
+ ds8.convert<float, sycl::rounding_mode::automatic>();
+
+ // second part effectively subtracts 8 from each quant value
+ return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
+}
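The subtraction term restates the q4_0 decoding q = nibble − 8. With ds8 carrying (d8, s8), where s8 is effectively d8 times the sum of the whole q8_1 block,

\[
d_4 \sum_j (n_j - 8)\, d_8 u_j \;=\; d_4 d_8 \sum_j n_j u_j \;-\; 8\, d_4 d_8 \sum_j u_j ,
\]

and per call this is realized as d4·(sumi·ds8.x − (8·vdr/QI4_0)·ds8.y): the QI4_0/vdr calls that together cover one block each contribute their own sumi plus a vdr/QI4_0 fraction of the full-block subtraction, so the totals match.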
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ 4
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u,
+ const sycl::half2 &dm4,
+ const sycl::half2 &ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
+ sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
+ }
+
+#ifdef GGML_SYCL_F16
+ const sycl::float2 tmp =
+ (dm4 * ds8).convert<float, sycl::rounding_mode::automatic>();
+ const float d4d8 = tmp.x();
+ const float m4s8 = tmp.y();
+#else
+ const sycl::float2 dm4f =
+ dm4.convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 ds8f =
+ ds8.convert<float, sycl::rounding_mode::automatic>();
+ const float d4d8 = dm4f.x() * ds8f.x();
+ const float m4s8 = dm4f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+ // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+ return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ 4
+
+template <int vdr>
+static __dpct_inline__ float
+vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u,
+ const float &d5, const sycl::half2 &ds8) {
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = dpct::dp4a(vi0, u[2 * i + 0],
+ sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = dpct::dp4a(vi1, u[2 * i + 1],
+ sumi); // SIMD dot product of quantized values
+ }
+
+ const sycl::float2 ds8f =
+ ds8.convert<float, sycl::rounding_mode::automatic>();
+
+ // second part effectively subtracts 16 from each quant value
+ return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y());
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ 4
+
+template <int vdr>
+static __dpct_inline__ float
+vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u,
+ const sycl::half2 &dm5, const sycl::half2 &ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = dpct::dp4a(vi0, u[2 * i + 0],
+ sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = dpct::dp4a(vi1, u[2 * i + 1],
+ sumi); // SIMD dot product of quantized values
+ }
+
+#ifdef GGML_SYCL_F16
+ const sycl::float2 tmp =
+ (dm5 * ds8).convert<float, sycl::rounding_mode::automatic>();
+ const float d5d8 = tmp.x();
+ const float m5s8 = tmp.y();
+
+
+#else
+ const sycl::float2 dm5f =
+ dm5.convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 ds8f =
+ ds8.convert<float, sycl::rounding_mode::automatic>();
+ const float d5d8 = dm5f.x() * ds8f.x();
+ const float m5s8 = dm5f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+ // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+ return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
+ const float &d8_0,
+ const float &d8_1) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = dpct::dp4a(v[i], u[i], sumi);
+ }
+
+ return d8_0*d8_1 * sumi;
+}
+
+template <int vdr>
+static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
+ const sycl::half2 &dm8,
+ const sycl::half2 &ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = dpct::dp4a(v[i], u[i], sumi);
+ }
+
+#ifdef GGML_SYCL_F16
+ const sycl::float2 tmp =
+ (dm8 * ds8).convert<float, sycl::rounding_mode::automatic>();
+ const float d8d8 = tmp.x();
+ const float m8s8 = tmp.y();
+#else
+ const sycl::float2 dm8f =
+ dm8.convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 ds8f =
+ ds8.convert<float, sycl::rounding_mode::automatic>();
+ const float d8d8 = dm8f.x() * ds8f.x();
+ const float m8s8 = dm8f.y() * ds8f.y();
+#endif // GGML_SYCL_F16
+
+ // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+ return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_0_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+ int v[VDR_Q4_0_Q8_1_MMVQ];
+ int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_uint8(bq4_0->qs, iqs + i);
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
+ }
+
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_1_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+ int v[VDR_Q4_1_Q8_1_MMVQ];
+ int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
+ }
+
+ return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+ int vl[VDR_Q5_0_Q8_1_MMVQ];
+ int vh[VDR_Q5_0_Q8_1_MMVQ];
+ int u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i);
+ vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0);
+ }
+
+ return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q5_1_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+ int vl[VDR_Q5_1_Q8_1_MMVQ];
+ int vh[VDR_Q5_1_Q8_1_MMVQ];
+ int u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
+ vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
+ }
+
+ return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __dpct_inline__ float
+vec_dot_q8_0_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+ int v[VDR_Q8_0_Q8_1_MMVQ];
+ int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
+ u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
+ }
+
+ return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d,
+ bq8_1->ds[0]);
+}
+
+static __dpct_inline__ float
+vec_dot_q2_K_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const uint8_t * scales = bq2_K->scales + scale_offset;
+
+ const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
+ int u[QR2_K];
+ float d8[QR2_K];
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++ i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + i].ds[0];
+ }
+
+ return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __dpct_inline__ float
+vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+ const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const float d = bq3_K->d;
+
+ const int vl = get_int_from_uint8(bq3_K->qs, iqs);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + i].ds[0];
+ }
+
+ return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __dpct_inline__ float
+vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#ifndef GGML_QKK_64
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
+
+ // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+ // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+ // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+ // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+ // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+ const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ v[0] = q4[0];
+ v[1] = q4[4];
+
+ const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int i = 0; i < QR4_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = bq8i->ds[0];
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+
+#else
+
+#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ const uint16_t * a = (const uint16_t *)bq4_K->scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ const float dall = bq4_K->dm[0];
+ const float dmin = bq4_K->dm[1];
+
+ const float d8_1 = bq8_1[0].ds[0];
+    const float d8_2 = bq8_1[1].ds[0];
+
+ const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+ const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+ const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+ const int * q4 = (const int *)bq4_K->qs + (iqs/2);
+ const int v1 = q4[0];
+ const int v2 = q4[4];
+
+ const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+ const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+ const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0));
+ const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0));
+
+ sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+ sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+ return dall * sumf_d - dmin * sumf_m;
+
+#else
+ bad_arch();
+#endif // __SYCL_ARCH__ >= VER_4VEC
+
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#ifndef GGML_QKK_64
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
+
+ const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+ const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+ vl[0] = ql[0];
+ vl[1] = ql[4];
+
+ vh[0] = qh[0] >> bq8_offset;
+ vh[1] = qh[4] >> bq8_offset;
+
+ const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = bq8i->ds[0];
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+
+#else
+
+#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+
+ const int8_t * s = bq5_K->scales;
+
+ const float d = bq5_K->d;
+
+ const float d8_1 = bq8_1[0].ds[0];
+    const float d8_2 = bq8_1[1].ds[0];
+
+ const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
+ const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
+ const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+
+ const int * ql = (const int *)bq5_K->qs + (iqs/2);
+ const int vl1 = ql[0];
+ const int vl2 = ql[4];
+
+ const int step = 4 * (iqs/2); // 0, 4, 8, 12
+ const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+ const int in = step%8; // 0, 4, 0, 4
+ const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+ const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+ const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+ const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+ const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+ const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
+ + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
+
+ return d * sumf_d;
+
+#else
+ bad_arch();
+#endif // __SYCL_ARCH__ >= VER_4VEC
+
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+ const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+ const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+ const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+ const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+ const int8_t * scales = bq6_K->scales + scale_offset;
+
+ int u[QR6_K];
+ float d8[QR6_K];
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+ d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
+ }
+
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
+ const uint8_t *kmask_iq2xs) {
+#if QK_K == 256
+ const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+ const int ib32 = iqs;
+ const uint16_t * q2 = bq2->qs + 4*ib32;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ uint32_t aux32 = q2[2] | (q2[3] << 16);
+ int sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+ const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+ for (int j = 0; j < 8; ++j) {
+ sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+ }
+ q8 += 8;
+ aux32 >>= 7;
+ }
+ const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
+ return d * sumi;
+#else
+ assert(false);
+ return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
+ MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+ const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+
+ const int ib32 = iqs;
+ const uint16_t * q2 = bq2->qs + 4*ib32;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
+ int sumi1 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs[0], signs[0], std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs[1], signs[1], std::minus<>());
+ sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+ sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+ q8 += 8;
+ }
+ int sumi2 = 0;
+ for (int l = 2; l < 4; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs[0], signs[0], std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs[1], signs[1], std::minus<>());
+ sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+ sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+ q8 += 8;
+ }
+ const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+ assert(false);
+ return 0.f;
+#endif
+#else
+ assert(false);
+ return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+
+ const int ib32 = iqs;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
+ const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+ const uint8_t ls2 = bq2->scales[ib32] >> 4;
+ int sumi1 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs0, signs0, std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs1, signs1, std::minus<>());
+ sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+ sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
+ q8 += 8;
+ }
+ int sumi2 = 0;
+ for (int l = 2; l < 4; ++l) {
+ const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300)));
+ const uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+ ((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201,
+ std::equal_to<>());
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid[0] ^ signs0, signs0, std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid[1] ^ signs1, signs1, std::minus<>());
+ sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+ sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
+ q8 += 8;
+ }
+ const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+ return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+#else
+ assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
+ MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+ const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+
+ const int ib32 = iqs;
+ const uint8_t * q3 = bq2->qs + 8*ib32;
+ const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ uint32_t aux32 = gas[0] | (gas[1] << 16);
+ int sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+ const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid1[0] ^ signs[0], signs[0], std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid2[0] ^ signs[1], signs[1], std::minus<>());
+ sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+ sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+ q8 += 8;
+ aux32 >>= 7;
+ }
+ const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
+ return d * sumi;
+#else
+ assert(false);
+ return 0.f;
+#endif
+#else
+ assert(false);
+ return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint32_t *iq3s_grid) {
+#if QK_K == 256
+ const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+ const int ib32 = iqs;
+ const uint8_t * qs = bq2->qs + 8*ib32;
+ const int8_t * q8 = bq8_1[ib32].qs;
+ int sumi = 0;
+ for (int l = 0; l < 4; ++l) {
+ const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+ const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+ uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+ ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+ 0x08040201, std::equal_to<>());
+ uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+ ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+ 0x08040201, std::equal_to<>());
+ const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+ grid1[0] ^ signs0, signs0, std::minus<>());
+ const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+ grid2[0] ^ signs1, signs1, std::minus<>());
+ sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+ sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+ q8 += 8;
+ }
+ const float d =
+ (float)bq2->d *
+ (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+ bq8_1[ib32].ds[0];
+ return d * sumi;
+#else
+ assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+ const uint32_t *iq1s_grid_gpu) {
+#if QK_K == 256
+ const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+ const int ib32 = iqs;
+ int sumi = 0;
+ const int * q8 = (const int *)bq8_1[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
+ int grid0 = grid[0] & 0x0f0f0f0f;
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+ sumi = dpct::dp4a(q8[2 * l + 1], grid1,
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi));
+ }
+
+ const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
+ const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
+ const float d = d1q * bq8_1[ib32].ds[0];
+ const float m = d1q * bq8_1[ib32].ds[1];
+ return d * sumi + m * delta;
+#else
+ assert(false);
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_m_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+#if QK_K == 256
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+
+ const int ib32 = iqs;
+ int sumi[2] = {0, 0};
+ float sumf[2] = {0.f, 0.f};
+
+ const int * q8 = (const int *)bq8_1[ib32].qs;
+ for (int l = 0; l < 4; ++l) {
+ const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8)));
+ int grid0 = grid[0] & 0x0f0f0f0f;
+ int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
+ sumi[l / 2] = dpct::dp4a(q8[2 * l + 1], grid1,
+ dpct::dp4a(q8[2 * l + 0], grid0, sumi[l / 2]));
+ const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA;
+ const int sumy = dpct::dp4a(q8[2 * l + 1], 0x01010101,
+ dpct::dp4a(q8[2 * l + 0], 0x01010101, 0));
+ sumf[l/2] += delta*sumy;
+ }
+
+ iq1m_scale_t scale;
+ const uint16_t * sc = (const uint16_t *)bq1->scales;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = (float)scale.f16 * bq8_1[ib32].ds[0];
+ return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1));
+#else
+ assert(false);
+#endif
+}
+
+
+static __dpct_inline__ float
+vec_dot_iq4_nl_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+ const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+
+ const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
+ const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
+
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+ int v1, v2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+ const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
+ get_int_from_table_16(aux, values, v1, v2);
+ sumi1 = dpct::dp4a(v1, q8[l + 0], sumi1);
+ sumi2 = dpct::dp4a(v2, q8[l + 4], sumi2);
+ }
+
+ const float d = (float)bq->d * bq8_1->ds[0];
+ return d * (sumi1 + sumi2);
+}
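Both iq4_nl here and iq4_xs below replace the uniform "nibble − 8" mapping with the 16-entry non-uniform codebook kvalues_iq4nl, so once get_int_from_table_16 has expanded the packed indices the accumulated value is simply

\[
d \cdot d_8 \sum_j \mathrm{kvalues\_iq4nl}[n_j]\, u_j
\]

(for iq4_xs the per-32 sub-block scale ls − 32 multiplies in as well), with dp4a forming the sum four bytes at a time.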
+
+
+static __dpct_inline__ float
+vec_dot_iq4_xs_q8_1(const void *__restrict__ vbq,
+ const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+
+#if QK_K == 256
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+ const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
+
+ // iqs is 0...7
+ const int ib32 = iqs;
+ const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+ const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
+ const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+ const float d = (float)bq4->d * (ls - 32) * bq8_1[ib32].ds[0];
+ int v1, v2;
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < 4; ++j) {
+ get_int_from_table_16(q4[j], values, v1, v2);
+ sumi1 = dpct::dp4a(v1, q8[j + 0], sumi1);
+ sumi2 = dpct::dp4a(v2, q8[j + 4], sumi2);
+ }
+ return d * (sumi1 + sumi2);
+#else
+ assert(false);
+#endif
+}
+
+#endif // GGML_SYCL_VECDOTQ_HPP
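As a cross-check for the SIMD paths above, the simplest case (one q4_0 block against one q8_1 block, QK = 32) can be written out in scalar form. The structs below are simplified stand-ins for illustration only; the real ggml blocks store half-precision scales and q8_1 also carries the block sum:

#include <cstdint>

// Simplified stand-in layouts (illustrative, not the ggml definitions).
struct block_q4_0_ref { float d; uint8_t qs[16]; }; // 32 x 4-bit values, stored offset by 8
struct block_q8_1_ref { float d; int8_t  qs[32]; };

// Scalar equivalent of what vec_dot_q4_0_q8_1 accumulates across its calls for one block.
static float vec_dot_q4_0_q8_1_ref(const block_q4_0_ref & x, const block_q8_1_ref & y) {
    int sumi = 0; // sum of nibble * q8
    int sumu = 0; // sum of q8 (needed for the constant -8 offset)
    for (int j = 0; j < 16; ++j) {
        const int lo = x.qs[j] & 0x0F; // value j
        const int hi = x.qs[j] >> 4;   // value j + 16
        sumi += lo * y.qs[j] + hi * y.qs[j + 16];
        sumu += y.qs[j] + y.qs[j + 16];
    }
    return x.d * y.d * (sumi - 8 * sumu); // = d4 * d8 * sum((nibble - 8) * q8)
}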
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
new file mode 100644
index 00000000..6bcd81a7
--- /dev/null
+++ b/ggml/src/ggml-vulkan.cpp
@@ -0,0 +1,7022 @@
+#include "ggml-vulkan.h"
+#include <vulkan/vulkan_core.h>
+#ifdef GGML_VULKAN_RUN_TESTS
+#include <chrono>
+#endif
+
+#include <vulkan/vulkan.hpp>
+
+#include <algorithm>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <tuple>
+#include <vector>
+#include <sstream>
+#include <utility>
+#include <memory>
+#include <limits>
+#include <map>
+#include <mutex>
+
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-vulkan-shaders.hpp"
+
+#define VK_API_VERSION VK_API_VERSION_1_2
+
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+
+#define VK_VENDOR_ID_AMD 0x1002
+#define VK_VENDOR_ID_APPLE 0x106b
+#define VK_VENDOR_ID_INTEL 0x8086
+#define VK_VENDOR_ID_NVIDIA 0x10de
+
+#define VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN 0
+#define VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI 1
+#define VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE 2
+
+#define GGML_VK_MAX_NODES 8192
+
+#define MAX_VK_BUFFERS 256
+
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
+#define VK_CHECK(err, msg) \
+ do { \
+ vk::Result err_ = (err); \
+ if (err_ != vk::Result::eSuccess) { \
+ fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \
+ #err, to_string(err_).c_str(), __FILE__, __LINE__); \
+ exit(1); \
+ } \
+ } while (0)
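For context, VK_CHECK is meant to wrap vulkan.hpp calls that return a vk::Result with more than one success code; a minimal usage sketch (assuming a created vk::Device named device, the fence and message are illustrative):

// Abort with a readable message if waiting on a fence fails.
vk::Fence fence = device.createFence({});
VK_CHECK(device.waitForFences({ fence }, true, UINT64_MAX), "example wait fence");
device.destroyFence(fence);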
+
+#ifdef GGML_VULKAN_DEBUG
+#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
+#else
+#define VK_LOG_DEBUG(msg) ((void) 0)
+#endif // GGML_VULKAN_DEBUG
+
+struct ggml_backend_vk_context;
+
+struct vk_queue {
+ uint32_t queue_family_index;
+ vk::Queue queue;
+ vk::CommandPool pool;
+ uint32_t cmd_buffer_idx;
+ std::vector<vk::CommandBuffer> cmd_buffers;
+
+ vk::PipelineStageFlags stage_flags;
+};
+
+struct vk_pipeline_struct {
+ std::string name;
+ vk::ShaderModule shader_module;
+ vk::DescriptorSetLayout dsl;
+ std::vector<vk::DescriptorPool> descriptor_pools;
+ std::vector<vk::DescriptorSet> descriptor_sets;
+ uint32_t descriptor_set_idx;
+ vk::PipelineLayout layout;
+ vk::Pipeline pipeline;
+ uint32_t push_constant_size;
+ uint32_t parameter_count;
+ std::array<uint32_t, 3> wg_denoms;
+ uint32_t align;
+};
+
+typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
+typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
+
+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
+
+struct vk_matmul_pipeline_struct {
+ vk_pipeline l, m, s;
+ vk_pipeline a_l, a_m, a_s;
+};
+
+typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
+
+struct vk_device_struct;
+typedef std::shared_ptr<vk_device_struct> vk_device;
+typedef std::weak_ptr<vk_device_struct> vk_device_ref;
+
+struct vk_buffer_struct;
+typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
+typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
+
+struct ggml_backend_vk_buffer_type_context {
+ std::string name;
+ vk_device device;
+};
+
+GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
+GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
+static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_vk_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment,
+ /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
+ /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL,
+};
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+class vk_memory_logger;
+#endif
+static void ggml_vk_destroy_buffer(vk_buffer& buf);
+
+struct vk_device_struct {
+ std::mutex mutex;
+
+ vk::PhysicalDevice physical_device;
+ vk::PhysicalDeviceProperties properties;
+ std::string name;
+ uint64_t max_memory_allocation_size;
+ bool fp16;
+ vk::Device device;
+ uint32_t vendor_id;
+ vk_queue compute_queue;
+ vk_queue transfer_queue;
+ bool single_queue;
+ uint32_t descriptor_set_mode;
+ uint32_t subgroup_size;
+ bool uma;
+
+ size_t idx;
+
+ vk_matmul_pipeline pipeline_matmul_f32;
+ vk_matmul_pipeline pipeline_matmul_f32_f16;
+ vk_matmul_pipeline pipeline_matmul_f16;
+ vk_matmul_pipeline pipeline_matmul_f16_f32;
+ vk_pipeline pipeline_matmul_split_k_reduce;
+
+ vk_matmul_pipeline pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+
+ vk_matmul_pipeline pipeline_matmul_id_f32;
+ vk_matmul_pipeline pipeline_matmul_id_f16;
+ vk_matmul_pipeline pipeline_matmul_id_f16_f32;
+
+ vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+
+ vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
+
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+ vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
+ vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_mul_f32;
+ vk_pipeline pipeline_div_f32;
+ vk_pipeline pipeline_add_f32;
+ vk_pipeline pipeline_scale_f32;
+ vk_pipeline pipeline_sqr_f32;
+ vk_pipeline pipeline_clamp_f32;
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
+ vk_pipeline pipeline_norm_f32;
+ vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_gelu_f32;
+ vk_pipeline pipeline_silu_f32;
+ vk_pipeline pipeline_relu_f32;
+ vk_pipeline pipeline_diag_mask_inf_f32;
+ vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
+ vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
+ vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+ vk_pipeline pipeline_argsort_f32;
+ vk_pipeline pipeline_sum_rows_f32;
+
+ std::vector<vk_pipeline_ref> pipelines;
+
+ std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
+
+ vk::Fence fence;
+ vk_buffer sync_staging;
+
+ ggml_backend_buffer_type buffer_type;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+ std::unique_ptr<vk_memory_logger> memory_logger;
+#endif
+
+ ~vk_device_struct() {
+ VK_LOG_DEBUG("destroy device " << name);
+
+ device.destroyFence(fence);
+
+ ggml_vk_destroy_buffer(sync_staging);
+
+ device.destroyCommandPool(compute_queue.pool);
+ if (!single_queue) {
+ device.destroyCommandPool(transfer_queue.pool);
+ }
+
+ for (auto& pipeline : pipelines) {
+ if (pipeline.expired()) {
+ continue;
+ }
+
+ vk_pipeline pl = pipeline.lock();
+ ggml_vk_destroy_pipeline(device, pl);
+ }
+ pipelines.clear();
+
+ device.destroy();
+ }
+};
+
+struct vk_buffer_struct {
+ vk::Buffer buffer;
+ vk::DeviceMemory device_memory;
+ vk::MemoryPropertyFlags memory_property_flags;
+ void * ptr;
+ size_t size = 0;
+
+ vk_device device;
+
+ ~vk_buffer_struct() {
+ if (size == 0) {
+ return;
+ }
+ VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
+
+ device->device.freeMemory(device_memory);
+ device->device.destroyBuffer(buffer);
+ }
+};
+
+struct vk_subbuffer {
+ vk_buffer buffer;
+ uint64_t offset;
+ uint64_t size;
+};
+
+struct vk_semaphore {
+ vk::Semaphore s;
+ uint64_t value;
+};
+
+struct vk_submission {
+ vk::CommandBuffer buffer;
+ std::vector<vk_semaphore> wait_semaphores;
+ std::vector<vk_semaphore> signal_semaphores;
+};
+
+typedef std::vector<vk_submission> vk_sequence;
+
+struct vk_mat_mat_push_constants {
+ uint32_t M; uint32_t N; uint32_t K;
+ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+ uint32_t k_split;
+ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+};
+struct vk_mat_vec_push_constants {
+ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+ uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+};
+
+struct vk_mat_mat_id_push_constants {
+ uint32_t M; uint32_t N; uint32_t K;
+ uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+ uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+};
+struct vk_mat_vec_id_push_constants {
+ uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
+ uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+ uint32_t nei0; uint32_t ne11;
+};
+
+struct vk_op_push_constants {
+ uint32_t KX;
+ uint32_t KY;
+ float param1;
+ float param2;
+};
+
+struct vk_op_unary_push_constants {
+ uint32_t ne;
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+ uint32_t d_offset;
+ float param1; float param2;
+};
+
+struct vk_op_binary_push_constants {
+ uint32_t ne;
+ uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+ uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
+ uint32_t d_offset;
+ float param1; float param2;
+};
+
+struct vk_op_diag_mask_push_constants {
+ uint32_t ncols;
+ uint32_t rows_per_channel;
+ int32_t n_past;
+};
+
+struct vk_op_rope_push_constants {
+ uint32_t ncols;
+ uint32_t n_dims;
+ float freq_scale;
+ uint32_t p_delta_rows;
+ float freq_base;
+ float ext_factor;
+ float attn_factor;
+ float corr_dims[2];
+ float theta_scale;
+ uint32_t has_ff;
+};
+
+struct vk_op_soft_max_push_constants {
+ uint32_t KX;
+ uint32_t KY;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint32_t n_head_log2;
+};
+
+struct vk_op_argsort_push_constants {
+ uint32_t ncols;
+ uint32_t ncols_pad;
+ int32_t order;
+};
+
+// Allow pre-recording command buffers
+struct vk_staging_memcpy {
+ vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
+
+ void * dst;
+ const void * src;
+ size_t n;
+};
+
+struct vk_context {
+ size_t idx;
+
+ vk_submission * s;
+ std::vector<vk_sequence> seqs;
+
+ ggml_tensor * exit_tensor;
+
+ std::vector<vk_staging_memcpy> in_memcpys;
+ std::vector<vk_staging_memcpy> out_memcpys;
+
+ vk_queue * q;
+};
+
+struct ggml_tensor_extra_gpu {
+ size_t ctx_idx;
+
+ vk_buffer_ref buffer_gpu;
+ uint64_t offset;
+
+ void reset() {
+ ctx_idx = 0;
+ buffer_gpu.reset();
+ offset = 0;
+ }
+};
+
+struct ggml_vk_garbage_collector {
+ std::vector<vk_semaphore> tl_semaphores;
+ std::vector<vk_semaphore> semaphores;
+ std::vector<vk::Event> events;
+ std::vector<vk_buffer> temp_buffers;
+ std::vector<vk_context> contexts;
+};
+
+#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG)
+#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl
+
+static std::string format_size(size_t size) {
+ const size_t kib = 1024;
+ const size_t mib = kib * 1024;
+ const size_t gib = mib * 1024;
+
+ std::ostringstream oss;
+ oss << std::fixed << std::setprecision(2);
+
+ if (size >= gib) {
+ oss << static_cast<double>(size) / gib << " GiB";
+ } else if (size >= mib) {
+ oss << static_cast<double>(size) / mib << " MiB";
+ } else if (size >= kib) {
+ oss << static_cast<double>(size) / kib << " KiB";
+ } else {
+ oss << size << " B";
+ }
+
+ return oss.str();
+}
+
+static std::mutex log_mutex;
+
+class vk_memory_logger {
+public:
+ vk_memory_logger(): total_device(0), total_host(0) {}
+ void log_allocation(vk_buffer_ref buf_ref, size_t size);
+ void log_deallocation(vk_buffer_ref buf_ref);
+
+private:
+ std::map<vk::Buffer, size_t> allocations; // Track allocations
+ size_t total_device;
+ size_t total_host;
+};
+#else
+#define VK_LOG_MEMORY(msg) ((void) 0)
+#endif // GGML_VULKAN_MEMORY_DEBUG || GGML_VULKAN_DEBUG
+
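+// Per-backend state: the device, preallocated scratch buffers (x, y, split_k), a staging buffer, the buffer pool, and the compute/transfer recording contexts.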
+struct ggml_backend_vk_context {
+ std::string name;
+
+ vk_device device;
+
+ size_t semaphore_idx, event_idx;
+ ggml_vk_garbage_collector gc;
+ size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
+ vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+ vk::Fence fence;
+ vk_buffer staging;
+ size_t staging_size;
+ size_t staging_offset;
+
+ vk_buffer buffer_pool[MAX_VK_BUFFERS];
+
+ vk_context * compute_ctx;
+ vk_context * transfer_ctx;
+};
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
+ std::lock_guard<std::mutex> guard(log_mutex);
+ vk_buffer buf = buf_ref.lock();
+ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
+ const std::string type = device ? "device" : "host";
+ allocations[buf->buffer] = size;
+ total_device += device ? size : 0;
+ total_host += device ? 0 : size;
+ VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+}
+
+void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
+ if (buf_ref.expired() || buf_ref.lock()->size == 0) {
+ return;
+ }
+
+ std::lock_guard<std::mutex> guard(log_mutex);
+ vk_buffer buf = buf_ref.lock();
+ const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
+ std::string type = device ? "device" : "host";
+ auto it = allocations.find(buf->buffer);
+ if (it != allocations.end()) {
+ total_device -= device ? it->second : 0;
+ total_host -= device ? 0 : it->second;
+ VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+ allocations.erase(it);
+ } else {
+ VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
+ }
+}
+#endif // GGML_VULKAN_MEMORY_DEBUG
+
+struct vk_instance_t {
+ vk::Instance instance;
+
+ std::vector<size_t> device_indices;
+ vk_device devices[GGML_VK_MAX_DEVICES];
+};
+
+static bool vk_instance_initialized = false;
+static vk_instance_t vk_instance;
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+static size_t vk_skip_checks;
+static size_t vk_output_tensor;
+
+static void ggml_vk_print_tensor(ggml_backend * ctx, const ggml_tensor * tensor, const char * name);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor);
+#endif
+
+typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
+
+static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+ VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
+ GGML_ASSERT(parameter_count > 0);
+ GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
+
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ pipeline = std::make_shared<vk_pipeline_struct>();
+ pipeline->name = name;
+ pipeline->parameter_count = parameter_count;
+ pipeline->push_constant_size = push_constant_size;
+ pipeline->wg_denoms = wg_denoms;
+ pipeline->align = align;
+
+ vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
+ pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
+
+ std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
+ std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
+ for (uint32_t i = 0; i < parameter_count; i++) {
+ dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
+ dsl_binding_flags.push_back({});
+ }
+
+ vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
+
+ vk::PushConstantRange pcr(
+ vk::ShaderStageFlagBits::eCompute,
+ 0,
+ pipeline->push_constant_size
+ );
+
+ vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
+ {},
+ dsl_binding);
+ descriptor_set_layout_create_info.setPNext(&dslbfci);
+ pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
+
+ // Check if the device supports allocating multiple descriptor sets from one pool
+ if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
+ const uint32_t alloc_count = 2;
+
+ // Try allocating multiple sets from one pool
+ // This fails on AMD for some reason, so add a fallback to allocating one pool per set
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
+ vk::DescriptorPool pool = device->device.createDescriptorPool(descriptor_pool_create_info);
+
+ std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
+ for (uint32_t i = 0; i < alloc_count; i++) {
+ layouts[i] = pipeline->dsl;
+ }
+ try {
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
+ std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ } catch(vk::OutOfPoolMemoryError const&) {
+ device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
+ }
+
+ device->device.destroyDescriptorPool(pool);
+ }
+
+ if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
+ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
+ }
+
+ pipeline->descriptor_set_idx = 0;
+
+ vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
+ pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
+
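+ // Each specialization constant occupies one consecutive 32-bit slot.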
+ std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
+
+ for (size_t i = 0; i < specialization_constants.size(); i++) {
+ specialization_entries[i].constantID = i;
+ specialization_entries[i].offset = i * sizeof(uint32_t);
+ specialization_entries[i].size = sizeof(uint32_t);
+ }
+
+ vk::SpecializationInfo specialization_info(
+ specialization_entries.size(),
+ specialization_entries.data(),
+ specialization_constants.size() * sizeof(uint32_t),
+ specialization_constants.data()
+ );
+
+ vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
+ vk::PipelineShaderStageCreateFlags(),
+ vk::ShaderStageFlagBits::eCompute,
+ pipeline->shader_module,
+ entrypoint.c_str(),
+ &specialization_info);
+ vk::ComputePipelineCreateInfo compute_pipeline_create_info(
+ vk::PipelineCreateFlags(),
+ pipeline_shader_create_info,
+ pipeline->layout);
+ pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
+
+ device->pipelines.push_back(pipeline);
+}
+
+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
+ VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
+ for (auto& pool : pipeline->descriptor_pools) {
+ device.destroyDescriptorPool(pool);
+ }
+ pipeline->descriptor_pools.clear();
+ pipeline->descriptor_sets.clear();
+ pipeline->descriptor_set_idx = 0;
+
+ device.destroyDescriptorSetLayout(pipeline->dsl);
+
+ device.destroyPipelineLayout(pipeline->layout);
+
+ device.destroyShaderModule(pipeline->shader_module);
+
+ device.destroyPipeline(pipeline->pipeline);
+}
+
+static void ggml_pipeline_allocate_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
+ VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
+ if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
+ // Enough descriptors are available
+ return;
+ }
+
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ if (device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+ const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
+
+ std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
+ for (uint32_t i = 0; i < alloc_count; i++) {
+ layouts[i] = pipeline->dsl;
+ }
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
+ std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
+ } else {
+ for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
+ vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
+ vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
+ pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
+
+ vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
+ std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+ pipeline->descriptor_sets.push_back(sets[0]);
+ }
+ }
+}
+
+static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
+ VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")");
+ pipeline->descriptor_set_idx = 0;
+}
+
+static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
+ VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ if (q.cmd_buffers.size() > q.cmd_buffer_idx) {
+ // Reuse command buffer
+ return q.cmd_buffers[q.cmd_buffer_idx++];
+ }
+
+ vk::CommandBufferAllocateInfo command_buffer_alloc_info(
+ q.pool,
+ vk::CommandBufferLevel::ePrimary,
+ 1);
+ const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
+ auto buf = cmd_buffers.front();
+
+ q.cmd_buffers.push_back(buf);
+ q.cmd_buffer_idx++;
+
+ return buf;
+}
+
+static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
+ VK_LOG_DEBUG("ggml_vk_create_submission()");
+ vk_submission s;
+ s.buffer = ggml_vk_create_cmd_buffer(device, q);
+ s.wait_semaphores = std::move(wait_semaphores);
+ s.signal_semaphores = std::move(signal_semaphores);
+ return s;
+}
+
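+// Submit all recorded sequences of this context in a single queue submit, translating the wait/signal values into timeline semaphore submit infos.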
+static void ggml_vk_submit(vk_context * ctx, vk::Fence fence) {
+ VK_LOG_DEBUG("ggml_vk_submit(" << ctx->seqs.size() << ", " << fence << ")");
+ if (ctx->seqs.empty()) {
+ return;
+ }
+
+ std::vector<std::vector<uint64_t>> tl_wait_vals;
+ std::vector<std::vector<uint64_t>> tl_signal_vals;
+ std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
+ std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
+ std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
+ std::vector<vk::SubmitInfo> submit_infos;
+ int idx = -1;
+ std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
+
+ size_t reserve = 0;
+
+ for (const auto& sequence : ctx->seqs) {
+ reserve += sequence.size();
+ }
+
+ // Pre-reserve vectors to prevent reallocation, which invalidates pointers
+ tl_wait_semaphores.reserve(reserve);
+ tl_wait_vals.reserve(reserve);
+ tl_signal_semaphores.reserve(reserve);
+ tl_signal_vals.reserve(reserve);
+ tl_submit_infos.reserve(reserve);
+ submit_infos.reserve(reserve);
+ stage_flags.reserve(reserve);
+
+ for (const auto& sequence : ctx->seqs) {
+ for (const auto& submission : sequence) {
+ stage_flags.push_back({});
+ idx++;
+ tl_wait_vals.push_back({});
+ tl_wait_semaphores.push_back({});
+ tl_signal_vals.push_back({});
+ tl_signal_semaphores.push_back({});
+ for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
+ stage_flags[idx].push_back(ctx->q->stage_flags);
+ tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
+ tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
+ }
+ for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
+ tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
+ tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
+ }
+ tl_submit_infos.push_back({
+ (uint32_t) submission.wait_semaphores.size(),
+ tl_wait_vals[idx].data(),
+ (uint32_t) submission.signal_semaphores.size(),
+ tl_signal_vals[idx].data(),
+ });
+ tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
+ tl_submit_infos[idx].pNext = nullptr;
+ vk::SubmitInfo si{
+ (uint32_t) submission.wait_semaphores.size(),
+ tl_wait_semaphores[idx].data(),
+ stage_flags[idx].data(),
+ 1,
+ &submission.buffer,
+ (uint32_t) submission.signal_semaphores.size(),
+ tl_signal_semaphores[idx].data(),
+ };
+ si.setPNext(&tl_submit_infos[idx]);
+ submit_infos.push_back(si);
+ }
+ }
+
+ ctx->q->queue.submit(submit_infos, fence);
+
+ ctx->seqs.clear();
+}
+
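+// Select a queue family with the required flags, preferring one that avoids the given flags and is distinct from the compute family; progressively relaxes the constraints if nothing matches.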
+static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
+ VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
+ const uint32_t qfsize = queue_family_props.size();
+
+ // Try with avoid preferences first
+ for (uint32_t i = 0; i < qfsize; i++) {
+ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
+ return i;
+ }
+ }
+
+ // Fall back to only required
+ for (size_t i = 0; i < qfsize; i++) {
+ if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
+ return i;
+ }
+ }
+
+ // Fall back to reusing compute queue
+ for (size_t i = 0; i < qfsize; i++) {
+ if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
+ return i;
+ }
+ }
+
+ // Fall back to ignoring min_num_queues
+ for (size_t i = 0; i < qfsize; i++) {
+ if (queue_family_props[i].queueFlags & required) {
+ return i;
+ }
+ }
+
+ // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
+ // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
+ if (compute_index >= 0) {
+ return compute_index;
+ }
+
+ std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
+
+ for(auto &q_family : queue_family_props) {
+ std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
+ }
+ abort();
+}
+
+static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags) {
+ VK_LOG_DEBUG("ggml_vk_create_queue()");
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ q.queue_family_index = queue_family_index;
+
+ vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
+ q.pool = device->device.createCommandPool(command_pool_create_info_compute);
+
+ q.cmd_buffer_idx = 0;
+
+ q.queue = device->device.getQueue(queue_family_index, queue_index);
+
+ q.stage_flags = stage_flags;
+}
+
+static vk_context * ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) {
+ VK_LOG_DEBUG("ggml_vk_create_context()");
+ ctx->gc.contexts.emplace_back();
+ vk_context * result = &ctx->gc.contexts[ctx->gc.contexts.size() - 1];
+ memset((void *) result, 0, sizeof(vk_context));
+ result->idx = ctx->gc.contexts.size() - 1;
+ result->q = &q;
+ return result;
+}
+
+static vk_context * ggml_vk_create_temporary_context(vk_queue& q) {
+ VK_LOG_DEBUG("ggml_vk_create_temporary_context()");
+ vk_context * result = new vk_context;
+ memset((void *) result, 0, sizeof(vk_context));
+ result->idx = 0;
+ result->q = &q;
+ return result;
+}
+
+static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
+ vk::SemaphoreCreateInfo ci{};
+ ci.setPNext(&tci);
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
+ ctx->gc.semaphores.push_back({ semaphore, 0 });
+ return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
+}
+
+static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
+ if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
+ vk::SemaphoreCreateInfo ci{};
+ ci.setPNext(&tci);
+ vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
+ ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
+ }
+ return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
+}
+
+static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
+ if (ctx->event_idx >= ctx->gc.events.size()) {
+ ctx->gc.events.push_back(ctx->device->device.createEvent({}));
+ }
+ return ctx->gc.events[ctx->event_idx++];
+}
+
+static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
+ VK_LOG_DEBUG("ggml_vk_queue_cleanup()");
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ // Requires command buffers to be done
+ device->device.resetCommandPool(q.pool);
+ q.cmd_buffer_idx = 0;
+}
+
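+// Find the first memory type that matches the buffer's requirements and the requested property flags; returns UINT32_MAX if none qualifies.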
+static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
+ for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
+ vk::MemoryType memory_type = mem_props->memoryTypes[i];
+ if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
+ (flags & memory_type.propertyFlags) == flags &&
+ mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
+ return i;
+ }
+ }
+ return UINT32_MAX;
+}
+
+static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
+ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")");
+ std::lock_guard<std::mutex> guard(device->mutex);
+
+ vk_buffer buf = std::make_shared<vk_buffer_struct>();
+
+ if (size == 0) {
+ buf->size = 0;
+ return buf;
+ }
+
+ buf->size = size;
+ vk::BufferCreateInfo buffer_create_info{
+ vk::BufferCreateFlags(),
+ size,
+ vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
+ vk::SharingMode::eExclusive,
+ 0,
+ nullptr,
+ };
+
+ buf->buffer = device->device.createBuffer(buffer_create_info);
+
+ vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
+
+ vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
+
+ uint32_t memory_type_index = UINT32_MAX;
+
+ memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
+ buf->memory_property_flags = req_flags;
+
+ if (memory_type_index == UINT32_MAX && fallback_flags) {
+ memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
+ buf->memory_property_flags = fallback_flags;
+ }
+
+ if (memory_type_index == UINT32_MAX) {
+ device->device.destroyBuffer(buf->buffer);
+ buf->size = 0;
+ throw vk::OutOfDeviceMemoryError("No suitable memory type found");
+ }
+
+ try {
+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
+ } catch (const vk::SystemError& e) {
+ // Out of Host/Device memory, clean up buffer
+ device->device.destroyBuffer(buf->buffer);
+ buf->size = 0;
+ throw e;
+ }
+ buf->ptr = nullptr;
+
+ if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+ buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+ }
+
+ device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
+
+ buf->device = device;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+ device->memory_logger->log_allocation(buf, size);
+#endif
+
+ return buf;
+}
+
+static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
+ try {
+ return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
+ } catch (const vk::SystemError& e) {
+ std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
+ std::cerr << "ggml_vulkan: " << e.what() << std::endl;
+ throw e;
+ }
+}
+
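+// Allocate a device-local buffer; on UMA devices host-visible/coherent memory is accepted as a fallback.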
+static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
+ vk_buffer buf;
+ try {
+ if (device->uma) {
+ // Fall back to host memory type
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ } else {
+ buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ }
+ } catch (const vk::SystemError& e) {
+ std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
+ std::cerr << "ggml_vulkan: " << e.what() << std::endl;
+ throw e;
+ }
+
+ return buf;
+}
+
+static void ggml_vk_destroy_buffer(vk_buffer& buf) {
+ if (buf == nullptr) {
+ return;
+ }
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+ if (buf->device != nullptr) {
+ buf->device->memory_logger->log_deallocation(buf);
+ }
+#endif
+
+ buf.reset();
+}
+
+static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
+ return { buf, 0, VK_WHOLE_SIZE };
+}
+
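+// Record a full memory barrier on the context's queue stages so prior writes are visible to subsequent reads and writes.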
+static void ggml_vk_sync_buffers(vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_sync_buffers()");
+ const std::vector<vk::MemoryBarrier> mem_barriers{ { { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite }, { vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite } } };
+
+ ctx->s->buffer.pipelineBarrier(
+ ctx->q->stage_flags,
+ ctx->q->stage_flags,
+ {},
+ mem_barriers,
+ {},
+ {}
+ );
+}
+
+static void ggml_vk_wait_events(vk_context * ctx, std::vector<vk::Event>&& events) {
+ VK_LOG_DEBUG("ggml_vk_wait_events()");
+ if (events.empty()) {
+ return;
+ }
+
+ ctx->s->buffer.waitEvents(
+ events,
+ ctx->q->stage_flags,
+ ctx->q->stage_flags,
+ {},
+ {},
+ {}
+ );
+}
+
+static void ggml_vk_load_shaders(vk_device& device) {
+ VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
+
+ // mulmat
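+ // Tile configurations for the matmul shaders (large/medium/small, plus quantized mmq variants); the values are passed to the shaders as specialization constants.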
+ std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+
+ std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+ std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+
+ std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
+ std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
+ std::array<uint32_t, 3> s_wg_denoms = { 32, 32, 1 };
+
+ uint32_t l_align = 128;
+ uint32_t m_align = 64;
+ uint32_t s_align = 32;
+
+ device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
+
+ device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
+ device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
+
+ if (device->fp16) {
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ } else {
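+ // Alternative branch: registers the fp32 ("*_fp32") shader variants of the
+ // matmul, matmul_id and quantized matmul pipelines instead.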
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
+
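+ // Quantized matmul pipelines, fp32 shader variants.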
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
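+ // matmul_id pipelines, fp32 shader variants; note these use vk_mat_mat_id_push_constants
+ // and one extra shader parameter compared to the plain matmul pipelines above.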
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
+
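+ // Quantized matmul_id pipelines, fp32 shader variants.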
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
+ }
+
+ // mul mat vec
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+ // dequant shaders
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+
+ // get_rows
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
+
+ ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+}
+
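+// Returns the vk_device for the given index, creating and caching it on first use:
+// the physical device is selected, its properties/features are queried, the logical
+// device with compute and transfer queues is created, and the compute shaders are loaded.
+// Subsequent calls return the cached instance.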
+static vk_device ggml_vk_get_device(size_t idx) {
+ VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
+
+ if (vk_instance.devices[idx] == nullptr) {
+ VK_LOG_DEBUG("Initializing new vk_device");
+ vk_device device = std::make_shared<vk_device_struct>();
+ vk_instance.devices[idx] = device;
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+ device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
+#endif
+
+ size_t dev_num = vk_instance.device_indices[idx];
+
+ std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ if (dev_num >= physical_devices.size()) {
+ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
+ throw std::runtime_error("Device not found");
+ }
+
+ device->physical_device = physical_devices[dev_num];
+ const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
+
+ bool maintenance4_support = false;
+
+ // Check if maintenance4 is supported
+ for (const auto& properties : ext_props) {
+ if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+ maintenance4_support = true;
+ }
+ }
+
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceMaintenance3Properties props3;
+ vk::PhysicalDeviceMaintenance4Properties props4;
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
+ props2.pNext = &props3;
+ props3.pNext = &subgroup_props;
+ if (maintenance4_support) {
+ subgroup_props.pNext = &props4;
+ }
+ device->physical_device.getProperties2(&props2);
+ device->properties = props2.properties;
+
+ const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
+
+ if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
+ device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
+ } else if (maintenance4_support) {
+ device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
+ } else {
+ device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
+ }
+
+ device->vendor_id = device->properties.vendorID;
+ device->subgroup_size = subgroup_props.subgroupSize;
+ device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+ bool fp16_storage = false;
+ bool fp16_compute = false;
+
+ for (const auto& properties : ext_props) {
+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+ fp16_storage = true;
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+ fp16_compute = true;
+ }
+ }
+
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
+ const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
+
+ device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
+
+ std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
+
+ // Try to find a non-graphics compute queue and transfer-focused queues
+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
+
+ const float priorities[] = { 1.0f, 1.0f };
+ device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
+
+ std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
+ if (compute_queue_family_index != transfer_queue_family_index) {
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
+ } else if(!device->single_queue) {
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
+ } else {
+ device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+ }
+ vk::DeviceCreateInfo device_create_info;
+ std::vector<const char *> device_extensions;
+ vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
+
+ VkPhysicalDeviceFeatures2 device_features2;
+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+ device_features2.pNext = nullptr;
+ device_features2.features = (VkPhysicalDeviceFeatures)device_features;
+
+ VkPhysicalDeviceVulkan11Features vk11_features;
+ vk11_features.pNext = nullptr;
+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
+ device_features2.pNext = &vk11_features;
+
+ VkPhysicalDeviceVulkan12Features vk12_features;
+ vk12_features.pNext = nullptr;
+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
+ vk11_features.pNext = &vk12_features;
+
+ vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
+
+ device->fp16 = device->fp16 && vk12_features.shaderFloat16;
+
+ if (!vk11_features.storageBuffer16BitAccess) {
+ std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
+ throw std::runtime_error("Unsupported device");
+ }
+
+ device_extensions.push_back("VK_KHR_16bit_storage");
+
+#ifdef GGML_VULKAN_VALIDATE
+ device_extensions.push_back("VK_KHR_shader_non_semantic_info");
+#endif
+
+ if (device->fp16) {
+ device_extensions.push_back("VK_KHR_shader_float16_int8");
+ }
+ device->name = device->properties.deviceName.data();
+
+ device_create_info = {
+ vk::DeviceCreateFlags(),
+ device_queue_create_infos,
+ {},
+ device_extensions
+ };
+ device_create_info.setPNext(&device_features2);
+ device->device = device->physical_device.createDevice(device_create_info);
+
+ device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
+
+ // Queues
+ ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
+
+ // Shaders
+ ggml_vk_load_shaders(device);
+
+ if (!device->single_queue) {
+ const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
+ ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
+ } else {
+ // TODO: Use pointer or reference to avoid copy
+ device->transfer_queue = device->compute_queue;
+ }
+
+ device->buffer_type = {
+ /* .iface = */ ggml_backend_vk_buffer_type_interface,
+ /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
+ };
+
+ device->fence = device->device.createFence({});
+
+ device->idx = idx;
+
+ return device;
+ }
+
+ return vk_instance.devices[idx];
+}
+
+
+static void ggml_vk_print_gpu_info(size_t idx) {
+ GGML_ASSERT(idx < vk_instance.device_indices.size());
+ size_t dev_num = vk_instance.device_indices[idx];
+ VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
+ GGML_ASSERT(vk_instance_initialized);
+
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ if (dev_num >= devices.size()) {
+ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
+ throw std::runtime_error("Device not found");
+ }
+
+ vk::PhysicalDevice physical_device = devices[dev_num];
+ std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
+
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceMaintenance3Properties props3;
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ props2.pNext = &props3;
+ props3.pNext = &subgroup_props;
+ subgroup_props.pNext = &driver_props;
+ physical_device.getProperties2(&props2);
+
+ const size_t subgroup_size = subgroup_props.subgroupSize;
+ const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+ bool fp16_storage = false;
+ bool fp16_compute = false;
+
+ for (auto properties : ext_props) {
+ if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+ fp16_storage = true;
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+ fp16_compute = true;
+ }
+ }
+
+ const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
+ bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
+
+ bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
+
+ vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+
+ VkPhysicalDeviceFeatures2 device_features2;
+ device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+ device_features2.pNext = nullptr;
+ device_features2.features = (VkPhysicalDeviceFeatures)device_features;
+
+ VkPhysicalDeviceVulkan11Features vk11_features;
+ vk11_features.pNext = nullptr;
+ vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
+ device_features2.pNext = &vk11_features;
+
+ VkPhysicalDeviceVulkan12Features vk12_features;
+ vk12_features.pNext = nullptr;
+ vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
+ vk11_features.pNext = &vk12_features;
+
+ vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
+
+ fp16 = fp16 && vk12_features.shaderFloat16;
+
+ std::string device_name = props2.properties.deviceName.data();
+ std::cerr << GGML_VK_NAME << idx << ": " << device_name << " (" << driver_props.driverName << ") | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl;
+
+ if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
+ std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
+ }
+}
+
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
+
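+// Initializes the Vulkan instance once: enumerates physical devices, honors
+// GGML_VK_VISIBLE_DEVICES if set, and otherwise defaults to all discrete GPUs,
+// keeping only one entry per deviceUUID (preferring the higher-priority driver).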
+void ggml_vk_instance_init() {
+ if (vk_instance_initialized) {
+ return;
+ }
+ VK_LOG_DEBUG("ggml_vk_instance_init()");
+
+ vk_instance_initialized = true;
+
+ vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
+
+ const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
+ const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
+#ifdef __APPLE__
+ const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
+#endif
+
+ std::vector<const char*> layers;
+
+ if (validation_ext) {
+ layers.push_back("VK_LAYER_KHRONOS_validation");
+ }
+ std::vector<const char*> extensions;
+ if (validation_ext) {
+ extensions.push_back("VK_EXT_validation_features");
+ }
+#ifdef __APPLE__
+ if (portability_enumeration_ext) {
+ extensions.push_back("VK_KHR_portability_enumeration");
+ }
+#endif
+ vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
+#ifdef __APPLE__
+ if (portability_enumeration_ext) {
+ instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
+ }
+#endif
+
+ std::vector<vk::ValidationFeatureEnableEXT> features_enable;
+ vk::ValidationFeaturesEXT validation_features;
+
+ if (validation_ext) {
+ features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
+ validation_features = {
+ features_enable,
+ {},
+ };
+ validation_features.setPNext(nullptr);
+ instance_create_info.setPNext(&validation_features);
+
+ std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
+ }
+ vk_instance.instance = vk::createInstance(instance_create_info);
+
+ size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
+
+ // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
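+    // e.g. GGML_VK_VISIBLE_DEVICES=0,2 restricts ggml to physical devices 0 and 2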
+ char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
+ if (devices_env != nullptr) {
+ std::string devices(devices_env);
+ std::replace(devices.begin(), devices.end(), ',', ' ');
+
+ std::stringstream ss(devices);
+ size_t tmp;
+ while (ss >> tmp) {
+ if(tmp >= num_available_devices) {
+ std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
+ throw std::runtime_error("Invalid Vulkan device index");
+ }
+ vk_instance.device_indices.push_back(tmp);
+ }
+ } else {
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ // Make sure at least one device exists
+ if (devices.empty()) {
+ std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
+ GGML_ASSERT(false);
+ }
+
+ // Default to using all dedicated GPUs
+ for (size_t i = 0; i < devices.size(); i++) {
+ vk::PhysicalDeviceProperties2 new_props;
+ vk::PhysicalDeviceDriverProperties new_driver;
+ vk::PhysicalDeviceIDProperties new_id;
+ new_props.pNext = &new_driver;
+ new_driver.pNext = &new_id;
+ devices[i].getProperties2(&new_props);
+
+ if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
+ // Check if there are two physical devices corresponding to the same GPU
+ auto old_device = std::find_if(
+ vk_instance.device_indices.begin(),
+ vk_instance.device_indices.end(),
+ [&devices, &new_id](const size_t k){
+ vk::PhysicalDeviceProperties2 old_props;
+ vk::PhysicalDeviceIDProperties old_id;
+ old_props.pNext = &old_id;
+ devices[k].getProperties2(&old_props);
+ return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
+ }
+ );
+ if (old_device == vk_instance.device_indices.end()) {
+ vk_instance.device_indices.push_back(i);
+ } else {
+ // There can be two physical devices corresponding to the same GPU if there are 2 different drivers
+                    // This can cause errors when splitting layers across the devices, so only one of them is kept
+ VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
+
+ vk::PhysicalDeviceProperties2 old_props;
+ vk::PhysicalDeviceDriverProperties old_driver;
+ old_props.pNext = &old_driver;
+ devices[*old_device].getProperties2(&old_props);
+
+ std::map<vk::DriverId, int> driver_priorities {};
+ int old_priority = std::numeric_limits<int>::max();
+ int new_priority = std::numeric_limits<int>::max();
+
+                    // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver IDs
+ // Smaller number -> higher priority
+ switch (old_props.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ driver_priorities[vk::DriverId::eMesaRadv] = 1;
+ driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
+ driver_priorities[vk::DriverId::eAmdProprietary] = 3;
+ break;
+ case VK_VENDOR_ID_INTEL:
+ driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
+ driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
+ break;
+ case VK_VENDOR_ID_NVIDIA:
+ driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
+#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
+ driver_priorities[vk::DriverId::eMesaNvk] = 2;
+#endif
+ break;
+ }
+
+ if (driver_priorities.count(old_driver.driverID)) {
+ old_priority = driver_priorities[old_driver.driverID];
+ }
+ if (driver_priorities.count(new_driver.driverID)) {
+ new_priority = driver_priorities[new_driver.driverID];
+ }
+
+ if (new_priority < old_priority) {
+ auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
+ vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
+ vk_instance.device_indices.push_back(i);
+
+ VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
+ }
+ else {
+                        VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName);
+ }
+ }
+ }
+ }
+
+ // If no dedicated GPUs found, fall back to GPU 0
+ if (vk_instance.device_indices.empty()) {
+ vk_instance.device_indices.push_back(0);
+ }
+ }
+
+ std::cerr << "ggml_vulkan: Found " << vk_instance.device_indices.size() << " Vulkan devices:" << std::endl;
+
+ for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
+ ggml_vk_print_gpu_info(i);
+ }
+}
+
+static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
+ GGML_ASSERT(idx < vk_instance.device_indices.size());
+ VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
+ ggml_vk_instance_init();
+
+ ctx->name = GGML_VK_NAME + std::to_string(idx);
+
+ ctx->device = ggml_vk_get_device(idx);
+
+ ctx->semaphore_idx = 0;
+ ctx->event_idx = 0;
+
+ ctx->prealloc_size_x = 0;
+ ctx->prealloc_size_y = 0;
+ ctx->prealloc_size_split_k = 0;
+
+ ctx->fence = ctx->device->device.createFence({});
+
+ ctx->staging_size = 0;
+ ctx->staging_offset = 0;
+
+ ctx->compute_ctx = nullptr;
+ ctx->transfer_ctx = nullptr;
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+ const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
+ vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
+ const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
+ vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
+#endif
+}
+
+static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
+ VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
+ switch (type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return nullptr;
+ }
+
+ return ctx->device->pipeline_dequant[type];
+}
+
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline()");
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_f32;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_matmul_f32_f16;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_f16_f32;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_matmul_f16;
+ }
+
+ GGML_ASSERT(src1_type == GGML_TYPE_F32);
+
+ switch (src0_type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return nullptr;
+ }
+
+ return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
+}
+
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
+ VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
+ GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
+
+ switch (a_type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return nullptr;
+ }
+
+ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
+}
+
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+ VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_id_f32;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_matmul_id_f16_f32;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_matmul_id_f16;
+ }
+
+ GGML_ASSERT(src1_type == GGML_TYPE_F32);
+
+ switch (src0_type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return nullptr;
+ }
+
+ return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
+}
+
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
+    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
+ GGML_ASSERT(b_type == GGML_TYPE_F32);
+
+ switch (a_type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return nullptr;
+ }
+
+ return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
+}
+
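+// Returns a buffer of at least `size` bytes from the context's buffer pool.
+// The smallest pooled buffer that fits is reused; if none fits, the largest
+// pooled buffer is destroyed and a new device buffer is allocated.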
+static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
+ VK_LOG_MEMORY("ggml_vk_pool_malloc");
+
+ int best_i = -1;
+ size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
+ int worst_i = -1;
+ size_t worst_size = 0; //largest unused buffer seen so far
+ for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
+ vk_buffer &b = ctx->buffer_pool[i];
+ if (b != nullptr && b->size >= size && b->size < best_size) {
+ best_i = i;
+ best_size = b->size;
+ }
+ if (b != nullptr && b->size > worst_size) {
+ worst_i = i;
+ worst_size = b->size;
+ }
+ }
+ if(best_i != -1) {
+ //found the smallest buffer that fits our needs
+ vk_buffer b = ctx->buffer_pool[best_i];
+ ctx->buffer_pool[best_i].reset();
+ return b;
+ }
+ if(worst_i != -1) {
+        // no pooled buffer fits, so destroy the largest one to free memory before allocating a new buffer
+ vk_buffer& b = ctx->buffer_pool[worst_i];
+ ggml_vk_destroy_buffer(b);
+ }
+
+ return ggml_vk_create_buffer_device(ctx->device, size);
+}
+
+static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
+ VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
+ for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
+ vk_buffer& b = ctx->buffer_pool[i];
+ if (b == nullptr) {
+ b = buffer;
+ return;
+ }
+ }
+ std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
+ ggml_vk_destroy_buffer(buffer);
+}
+
+// Returns a temporary buffer from the context's pool. It is only valid for short-lived use, since it will be reused by later calls.
+static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
+ // Try to find existing temp buffer with enough capacity
+ for (auto& buffer : ctx->gc.temp_buffers) {
+ if (buffer->size >= size) {
+ return buffer;
+ }
+ }
+
+ VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
+
+ // Otherwise create new buffer
+ vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
+ ctx->gc.temp_buffers.push_back(buf);
+
+ return buf;
+}
+
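+// Allocates pinned (host-visible, host-coherent) memory for asynchronous transfers and
+// tracks it in device->pinned_memory. Returns nullptr if no host-visible allocation could be made.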
+static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
+ VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
+ vk_buffer buf = ggml_vk_create_buffer(device, size,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+
+ if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
+ fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
+ size/1024.0/1024.0);
+ device->device.freeMemory(buf->device_memory);
+ device->device.destroyBuffer(buf->buffer);
+ return nullptr;
+ }
+
+ device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
+
+ return buf->ptr;
+}
+
+static void ggml_vk_host_free(vk_device& device, void* ptr) {
+ if (ptr == nullptr) {
+ return;
+ }
+ VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
+ vk_buffer buf;
+ size_t index;
+ for (size_t i = 0; i < device->pinned_memory.size(); i++) {
+ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
+ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
+ if (ptr >= addr && ptr < endr) {
+ buf = std::get<2>(device->pinned_memory[i]);
+ index = i;
+ break;
+ }
+ }
+ if (buf == nullptr) {
+ fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
+ return;
+ }
+
+ ggml_vk_destroy_buffer(buf);
+
+ device->pinned_memory.erase(device->pinned_memory.begin() + index);
+}
+
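+// Checks whether `ptr` lies inside a pinned allocation made by ggml_vk_host_malloc.
+// On success `buf` is set to the backing buffer and `buf_offset` to the offset of `ptr`
+// within it; otherwise `buf` is left as nullptr.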
+static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
+ buf = nullptr;
+ buf_offset = 0;
+ for (size_t i = 0; i < device->pinned_memory.size(); i++) {
+ const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
+ const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
+ if (ptr >= addr && ptr < endr) {
+ buf = std::get<2>(device->pinned_memory[i]);
+ buf_offset = ((const uint8_t *)ptr) - addr;
+ break;
+ }
+ }
+}
+
+static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) {
+ vk_submission s;
+ s.buffer = ggml_vk_create_cmd_buffer(device, q);
+ if (one_time) {
+ s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
+ } else {
+ s.buffer.begin({ vk::CommandBufferUsageFlags{} });
+ }
+
+ return s;
+}
+
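+// Writes the descriptor set for the given buffers, binds the pipeline, descriptor set and
+// push constants, then records a compute dispatch. The workgroup count per dimension is
+// CEIL_DIV(elements[i], wg_denoms[i]).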
+static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
+ const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
+ const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
+ const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
+ VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
+ for (auto& buffer : buffers) {
+ std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.size << "), ";
+ }
+ std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
+ std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
+ std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
+ GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
+ GGML_ASSERT(buffers.size() == pipeline->parameter_count);
+ vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
+ descriptor_buffer_infos.push_back({buffers[i].buffer->buffer, buffers[i].offset, buffers[i].size});
+ }
+ for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
+ write_descriptor_sets.push_back({descriptor_set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &descriptor_buffer_infos[i]});
+ }
+
+ ctx->device->device.updateDescriptorSets(write_descriptor_sets, {});
+
+ subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
+ subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
+ subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
+ pipeline->layout,
+ 0,
+ { descriptor_set },
+ {});
+ subctx->s->buffer.dispatch(wg0, wg1, wg2);
+}
+
+static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
+ s.buffer.end();
+
+ s.wait_semaphores = std::move(wait_semaphores);
+ s.signal_semaphores = std::move(signal_semaphores);
+}
+
+static void ggml_vk_ctx_end(vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
+ if (ctx->s == nullptr) {
+ return;
+ }
+
+ ctx->s->buffer.end();
+ ctx->s = nullptr;
+}
+
+static void ggml_vk_ctx_begin(vk_device& device, vk_context * subctx) {
+ VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
+ if (subctx->s != nullptr) {
+ ggml_vk_ctx_end(subctx);
+ }
+
+ subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) });
+ subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
+}
+
+static size_t ggml_vk_align_size(size_t width, size_t align) {
+ VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
+ return CEIL_DIV(width, align) * align;
+}
+
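+// Performs the copy immediately when no memcpy list is given; otherwise records it in
+// `memcpys` so the host-side copy can be executed later.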
+static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
+ if (memcpys == nullptr) {
+ memcpy(dst, src, size);
+ } else {
+ memcpys->emplace_back(dst, src, size);
+ }
+}
+
+static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
+ if (device->sync_staging == nullptr || device->sync_staging->size < size) {
+ VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
+ ggml_vk_destroy_buffer(device->sync_staging);
+ device->sync_staging = ggml_vk_create_buffer_check(device, size,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ }
+}
+
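+// Asynchronously writes a non-contiguous tensor into `dst`. If the source is pinned host
+// memory it is copied directly; otherwise the data is packed into a staging buffer, using
+// per-slice (or per-element) copies to make the destination contiguous.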
+static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
+ VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
+ GGML_ASSERT(!ggml_is_contiguous(tensor));
+ // Buffer is already mapped
+ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+ std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
+ GGML_ASSERT(false);
+ }
+ // Check if src is pinned memory
+ vk_buffer buf;
+ size_t buf_offset;
+ ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
+
+ const uint64_t ne0 = tensor->ne[0];
+ const uint64_t ne1 = tensor->ne[1];
+ const uint64_t ne2 = tensor->ne[2];
+ const uint64_t ne3 = tensor->ne[3];
+ const uint64_t nb0 = tensor->nb[0];
+ const uint64_t nb1 = tensor->nb[1];
+ const uint64_t nb2 = tensor->nb[2];
+ const uint64_t nb3 = tensor->nb[3];
+ const ggml_type type = tensor->type;
+ const uint64_t ts = ggml_type_size(type);
+ const uint64_t bs = ggml_blck_size(type);
+
+ const uint64_t dstnb0 = ts;
+ const uint64_t dstnb1 = dstnb0*(ne0/bs);
+ const uint64_t dstnb2 = dstnb1*ne1;
+ const uint64_t dstnb3 = dstnb2*ne2;
+
+ const uint64_t ne = ggml_nelements(tensor);
+
+ if (buf != nullptr) {
+ // Memory is pinned, use as staging buffer
+ std::vector<vk::BufferCopy> slices;
+
+ for (uint64_t i3 = 0; i3 < ne3; i3++) {
+ for (uint64_t i2 = 0; i2 < ne2; i2++) {
+ // Find longest contiguous slice
+ if (ne1*nb1 == dstnb2) {
+ slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
+ } else {
+ for (uint64_t i1 = 0; i1 < ne1; i1++) {
+ if (ne0*nb0/bs == dstnb1) {
+ slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
+ } else {
+ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
+ const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
+ for (uint64_t i0 = 0; i0 < ne0; i0++) {
+                                slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
+ return;
+ }
+
+ // Staging buffer required
+ vk_buffer staging = ctx->staging;
+ size_t staging_offset = ctx->staging_offset;
+ const size_t copy_size = ts*ne/bs;
+ if (ctx->staging->size < ctx->staging_offset + copy_size) {
+ if (sync_staging) {
+ // Create temporary larger buffer
+ ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
+
+ staging = ctx->device->sync_staging;
+ staging_offset = 0;
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ VkBufferCopy buf_copy{ staging_offset, offset, copy_size };
+
+ ggml_vk_sync_buffers(subctx);
+ vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
+
+ for (uint64_t i3 = 0; i3 < ne3; i3++) {
+ for (uint64_t i2 = 0; i2 < ne2; i2++) {
+ // Find longest contiguous slice
+ if (ne1*nb1 == dstnb2) {
+ deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
+ } else {
+ for (uint64_t i1 = 0; i1 < ne1; i1++) {
+ if (ne0*nb0/bs == dstnb1) {
+ deferred_memcpy((uint8_t *)staging->ptr + staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
+ } else {
+ const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
+ const uint64_t d_off = staging_offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
+ for (uint64_t i0 = 0; i0 < ne0; i0++) {
+ deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_vk_buffer_write_2d_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+ VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
+    // A host-visible destination buffer is already mapped; staged async writes do not apply to it, use the synchronous write path instead
+ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+ std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
+ GGML_ASSERT(false);
+ }
+ // Check if src is pinned memory
+ vk_buffer buf = nullptr;
+ size_t buf_offset;
+ ggml_vk_host_get(dst->device, src, buf, buf_offset);
+
+ if (buf != nullptr) {
+ // Memory is pinned, use as staging buffer
+ std::vector<vk::BufferCopy> slices(1);
+ if (width == spitch) {
+ // Only do single write if stride is equal
+ slices[0].srcOffset = buf_offset;
+ slices[0].dstOffset = offset;
+ slices[0].size = width * height;
+ } else {
+ slices.resize(height);
+ for (size_t i = 0; i < height; i++) {
+ slices[i].srcOffset = buf_offset + i * spitch;
+ slices[i].dstOffset = offset + i * width;
+ slices[i].size = width;
+ }
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
+ return;
+ }
+ VK_LOG_DEBUG("STAGING");
+
+ // Staging buffer required
+ const size_t copy_size = width*height;
+ if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
+ if (sync_staging) {
+ ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
+
+ staging_buffer = dst->device->sync_staging;
+ staging_offset = 0;
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ VkBufferCopy buf_copy = {
+ staging_offset,
+ offset,
+ copy_size};
+
+ ggml_vk_sync_buffers(subctx);
+ vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
+
+ if (width == spitch) {
+ deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset, src, width * height, &subctx->in_memcpys);
+ } else {
+ for (size_t i = 0; i < height; i++) {
+ deferred_memcpy((uint8_t *)staging_buffer->ptr + staging_offset + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
+ }
+ }
+}
+
+static void ggml_vk_buffer_write_async(vk_context * subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+ VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
+ return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, staging_buffer, staging_offset, sync_staging);
+}
+
+static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
+ VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
+    // Host-visible buffers are persistently mapped, so the data can be written directly with memcpy
+ if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+ GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
+
+ for (size_t i = 0; i < height; i++) {
+ memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
+ }
+ } else {
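+        // Device-local destination: stage the data through a temporary transfer context and wait for the copy to finish.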
+ vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+ ggml_vk_ctx_begin(dst->device, subctx);
+ ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, nullptr, 0, true);
+ ggml_vk_ctx_end(subctx);
+
+ for (auto& cpy : subctx->in_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+
+ ggml_vk_submit(subctx, dst->device->fence);
+ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
+ dst->device->device.resetFences({ dst->device->fence });
+
+ delete subctx;
+ }
+}
+
+static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
+ ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
+}
+
+static void ggml_vk_buffer_read_2d_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+ VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
+ GGML_ASSERT(width > 0);
+ GGML_ASSERT(height > 0);
+ GGML_ASSERT(src != nullptr);
+
+ // Check if dst is pinned memory
+ vk_buffer buf = nullptr;
+ size_t buf_offset;
+ ggml_vk_host_get(src->device, dst, buf, buf_offset);
+
+ std::vector<vk::BufferCopy> slices(1);
+ if (width == spitch && width == dpitch) {
+ // Only do single write if stride is equal
+ slices[0].srcOffset = offset;
+ slices[0].dstOffset = buf_offset;
+ slices[0].size = width * height;
+ } else {
+ slices.resize(height);
+ for (size_t i = 0; i < height; i++) {
+ slices[i].srcOffset = offset + i * spitch;
+ slices[i].dstOffset = buf_offset + i * dpitch;
+ slices[i].size = width;
+ }
+ }
+
+ if (buf != nullptr) {
+ // Memory is pinned, use as staging buffer
+ ggml_vk_sync_buffers(subctx);
+ subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
+
+ return;
+ }
+ VK_LOG_DEBUG("STAGING");
+
+ // Fall back to staging buffer
+ const size_t copy_size = dpitch * height;
+ if (staging_buffer == nullptr || staging_buffer->size < staging_offset + copy_size) {
+ if (sync_staging) {
+ // Create temporary larger buffer
+ ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
+
+ staging_buffer = src->device->sync_staging;
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
+
+ deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
+}
+
+static void ggml_vk_buffer_read_async(vk_context * subctx, vk_buffer& src, size_t offset, void * dst, size_t size, vk_buffer staging_buffer, size_t staging_offset, bool sync_staging = false) {
+ return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, staging_buffer, staging_offset, sync_staging);
+}
+
+static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_buffer_read(" << offset << ", " << size << ")");
+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
+ GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
+
+ memcpy(dst, (uint8_t *) src->ptr + offset, size);
+ } else {
+ vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+ ggml_vk_ctx_begin(src->device, subctx);
+ ggml_vk_buffer_read_async(subctx, src, offset, dst, size, nullptr, 0, true);
+ ggml_vk_ctx_end(subctx);
+
+ ggml_vk_submit(subctx, src->device->fence);
+ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
+ src->device->device.resetFences({ src->device->fence });
+
+ for (auto& cpy : subctx->out_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+
+ delete subctx;
+ }
+}
+
+static void ggml_vk_buffer_copy_async(vk_context * ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
+ // Make sure both buffers are on same device
+ GGML_ASSERT(src->device == dst->device);
+
+ VkBufferCopy bc{ src_offset, dst_offset, size };
+
+ vkCmdCopyBuffer(ctx->s->buffer, src->buffer, dst->buffer, 1, &bc);
+}
+
+static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
+ if (src->device == dst->device) {
+ VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
+ // Copy within the device
+ vk_context * subctx = ggml_vk_create_temporary_context(src->device->transfer_queue);
+ ggml_vk_ctx_begin(src->device, subctx);
+ ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
+ ggml_vk_ctx_end(subctx);
+ ggml_vk_submit(subctx, src->device->fence);
+ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
+ src->device->device.resetFences({ src->device->fence });
+
+ delete subctx;
+ } else {
+ VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
+ // Copy device to device
+ ggml_vk_ensure_sync_staging_buffer(src->device, size);
+ ggml_vk_ensure_sync_staging_buffer(dst->device, size);
+
+ // Copy to src staging buffer
+ ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
+ // memcpy to dst staging buffer
+ memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size);
+ // Copy to dst buffer
+ ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size);
+ }
+}
+
+static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
+ VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
+
+ vk_context * subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue);
+ ggml_vk_ctx_begin(dst->device, subctx);
+ subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
+ ggml_vk_ctx_end(subctx);
+
+ ggml_vk_submit(subctx, dst->device->fence);
+ VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
+ dst->device->device.resetFences({ dst->device->fence });
+
+ delete subctx;
+}
+
+static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
+ VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
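+    // Split-k is currently disabled (this always returns 1); the heuristic below is kept for reference.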
+ // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
+ // return 4;
+ // }
+
+ return 1;
+
+ GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
+}
+
+static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
+ if (m <= 32 || n <= 32) {
+ return aligned ? mmp->a_s : mmp->s;
+ }
+ return aligned ? mmp->a_m : mmp->m;
+
+ GGML_UNUSED(ctx);
+}
+
+static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
+ return aligned ? mmp->a_m : mmp->m;
+
+ GGML_UNUSED(ctx);
+}
+
+static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
+ return aligned ? mmp->a_s : mmp->s;
+
+ GGML_UNUSED(ctx);
+}
+
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
+ switch (ctx->device->vendor_id) {
+ case VK_VENDOR_ID_AMD:
+ return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
+ case VK_VENDOR_ID_APPLE:
+ return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
+ case VK_VENDOR_ID_INTEL:
+ return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
+ default:
+ break;
+ }
+
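+    // Default heuristic: pick the small, medium or large tile variant based on the output dimensions.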
+ if (m <= 32 || n <= 32) {
+ return aligned ? mmp->a_s : mmp->s;
+ }
+ if (m <= 64 || n <= 64) {
+ return aligned ? mmp->a_m : mmp->m;
+ }
+ return aligned ? mmp->a_l : mmp->l;
+}
+
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
+}
+
+static void ggml_vk_matmul(
+ ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
+ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
+ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
+ uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+ VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+ ggml_vk_sync_buffers(subctx);
+ if (split_k == 1) {
+ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
+ return;
+ }
+
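+    // With split_k > 1 each slab computes a partial result over CEIL_DIV(k, split_k) of the K dimension;
+    // the partial results are then summed into d by the split-k reduce pipeline below.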
+ GGML_ASSERT(batch_stride_d == m * n);
+
+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+ // Make sure enough workgroups get assigned for split k to work
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+ ggml_vk_sync_buffers(subctx);
+ const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
+}
+
+static void ggml_vk_matmul_id(
+ ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline,
+ vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
+ uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
+ uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
+ uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+ VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
+ "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
+ "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
+ "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
+ ggml_vk_sync_buffers(subctx);
+ const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
+ nei0, nei1, nbi1, ne11 };
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
+}
+
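+// Returns true if dims 0 and 1 are laid out contiguously; dim 2 may still be strided (e.g. a permuted batch dimension).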
+static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
+ return
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
+ tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
+ if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
+ return ctx->device->pipeline_cpy_f32_f32;
+ }
+ if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
+ return ctx->device->pipeline_cpy_f32_f16;
+ }
+ if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
+ return ctx->device->pipeline_cpy_f16_f16;
+ }
+
+ std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
+ GGML_ASSERT(false);
+}
+
+static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
+ VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
+ std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
+ const int tensor_type_size = ggml_type_size(tensor->type);
+
+ const uint32_t ne = ggml_nelements(tensor);
+
+ const vk_op_unary_push_constants pc = {
+ (uint32_t)ne,
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
+ 0,
+ 0.0f, 0.0f,
+ };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
+}
+
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ const uint64_t ne13 = src1->ne[3];
+
+ const uint64_t ne20 = dst->ne[0];
+ const uint64_t ne21 = dst->ne[1];
+
+ const uint64_t r2 = ne12 / ne02;
+ const uint64_t r3 = ne13 / ne03;
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+
+ vk_buffer d_Qx;
+ size_t qx_buf_offset = 0;
+ vk_buffer d_Qy;
+ size_t qy_buf_offset = 0;
+
+ bool src0_uma = false;
+ bool src1_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ src0_uma = d_Qx != nullptr;
+ src1_uma = d_Qy != nullptr;
+ }
+
+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+
+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
+
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+
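+    // src0 needs a dequantize/copy pass if no matmul pipeline exists for its type or it is non-contiguous;
+    // src1 needs one if it is not already f16 (and the f32 kernel does not apply) or is non-contiguous.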
+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+
+ if (mmp == nullptr) {
+ // Fall back to dequant + f16 mulmat
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
+ }
+
+ // Not implemented
+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
+ const int x_ne = ne01 * ne00;
+ const int y_ne = ne11 * ne10;
+ const int d_ne = ne11 * ne01;
+
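+    // Use the "aligned" pipeline variants only when K is already a multiple of the pipeline alignment
+    // and the matrix is not tiny; otherwise fall back to the unaligned variants.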
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
+ const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
+
+ const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
+
+ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03);
+ vk_buffer d_X;
+ uint64_t x_buf_offset = 0;
+ vk_buffer d_Y;
+ uint64_t y_buf_offset = 0;
+ if (!src0_uma) {
+ d_Qx = extra_src0->buffer_gpu.lock();
+ qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ }
+ if (!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+ GGML_ASSERT(d_Qy != nullptr);
+ }
+ if (qx_needs_dequant) {
+ d_X = ctx->prealloc_x;
+ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
+ } else {
+ d_X = d_Qx;
+ x_buf_offset = qx_buf_offset;
+ GGML_ASSERT(qx_sz == x_sz);
+ }
+ if (qy_needs_dequant) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ } else {
+ d_Y = d_Qy;
+ y_buf_offset = qy_buf_offset;
+ GGML_ASSERT(qy_sz == y_sz);
+ }
+
+ vk_pipeline to_fp16_vk_0 = nullptr;
+ vk_pipeline to_fp16_vk_1 = nullptr;
+
+ if (x_non_contig) {
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+ } else {
+ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
+ }
+ if (y_non_contig) {
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+ } else {
+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+ }
+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+
+ // Allocate descriptor sets
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
+ if (qx_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+ }
+ if (qy_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+ }
+ if (split_k > 1) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
+ }
+
+ if (x_non_contig) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ } else if (qx_needs_dequant) {
+ const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+ }
+ if (y_non_contig) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ }
+
+ uint32_t stride_batch_x = ne00*ne01;
+ uint32_t stride_batch_y = ne10*ne11;
+
+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
+ }
+
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+ }
+
+ // compute
+ ggml_vk_matmul(
+ ctx, subctx, pipeline,
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
+ ne01, ne11, ne10,
+ ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
+ split_k, ne12*ne13, ne02, ne12, r2, r3
+ ); // NOLINT
+}
+
+static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ const uint64_t ne13 = src1->ne[3];
+
+ GGML_ASSERT(ne11 == 1);
+
+ const uint64_t ne20 = dst->ne[0];
+ const uint64_t ne21 = dst->ne[1];
+ const uint64_t ne22 = dst->ne[2];
+ const uint64_t ne23 = dst->ne[3];
+
+ const uint64_t r2 = ne12 / ne02;
+ const uint64_t r3 = ne13 / ne03;
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+
+ vk_buffer d_Qx;
+ size_t qx_buf_offset = 0;
+ vk_buffer d_Qy;
+ size_t qy_buf_offset = 0;
+
+ bool src0_uma = false;
+ bool src1_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ src0_uma = d_Qx != nullptr;
+ src1_uma = d_Qy != nullptr;
+ }
+
+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+
+ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
+
+ const bool qx_needs_dequant = x_non_contig;
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+
+ // Not implemented
+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
+ const uint64_t x_ne = ne01 * ne00;
+ const uint64_t y_ne = ne11 * ne10;
+ const uint64_t d_ne = ne11 * ne01;
+
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+ const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ vk_buffer d_X;
+ uint64_t x_buf_offset = 0;
+ vk_buffer d_Y;
+ uint64_t y_buf_offset = 0;
+ if(!src0_uma) {
+ d_Qx = extra_src0->buffer_gpu.lock();
+ qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ }
+ if(!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+ GGML_ASSERT(d_Qy != nullptr);
+ }
+ if (qx_needs_dequant) {
+ d_X = ctx->prealloc_x;
+ } else {
+ d_X = d_Qx;
+ x_buf_offset = qx_buf_offset;
+ GGML_ASSERT(qx_sz == x_sz);
+ }
+ if (qy_needs_dequant) {
+ d_Y = ctx->prealloc_y;
+ } else {
+ d_Y = d_Qy;
+ y_buf_offset = qy_buf_offset;
+ GGML_ASSERT(qy_sz == y_sz);
+ }
+
+ vk_pipeline to_fp16_vk_0 = nullptr;
+ vk_pipeline to_fp16_vk_1 = nullptr;
+ if (x_non_contig) {
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+ }
+ if (y_non_contig) {
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+ } else {
+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+ }
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+ GGML_ASSERT(dmmv != nullptr);
+
+ // Allocate descriptor sets
+ if (qx_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+ }
+ if (qy_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+ }
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
+
+ if (x_non_contig) {
+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ }
+ if (y_non_contig) {
+ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ }
+
+ uint32_t stride_batch_x = ne00*ne01;
+ uint32_t stride_batch_y = ne10*ne11;
+
+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
+ }
+
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+ }
+
+ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
+
+ uint32_t groups_x = ne01;
+ uint32_t groups_z = 1;
+
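+    // If the number of rows exceeds the device's work group count limit in X, spread the rows over the Z dimension.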
+ if (ne01 > max_groups_x) {
+ groups_z = 64;
+ groups_x /= groups_z;
+ }
+
+ // compute
+ const vk_mat_vec_push_constants pc = {
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+ stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+ (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+ };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+ { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23} },
+ sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+}
+
+static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+ GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
+ GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ // const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ // const uint64_t ne13 = src1->ne[3];
+
+ GGML_ASSERT(ne11 == 1);
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+
+ vk_buffer d_Qy;
+ size_t qy_buf_offset = 0;
+
+ bool src1_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ src1_uma = d_Qy != nullptr;
+ }
+
+ const uint64_t x_ne = ne00 * ne01 * ne02;
+ const uint64_t y_ne = ne10 * ne11 * ne12;
+ const uint64_t d_ne = ne01 * ne11 * ne12;
+
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
+ const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ if (!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+        GGML_ASSERT(d_Qy != nullptr);
+ }
+
+ // Allocate descriptor sets
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+
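+    // Descriptor offsets must respect minStorageBufferOffsetAlignment: bind the buffers at the largest
+    // aligned offset below the actual offset and pass the remainder to the shader as an element offset.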
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
+
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
+
+ // compute
+ const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+}
+
+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(!ggml_is_permuted(src0));
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ // const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t nb01 = src0->nb[1];
+ const uint64_t nb02 = src0->nb[2];
+
+ // const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ // const uint64_t ne13 = src1->ne[3];
+
+ GGML_ASSERT(ne11 == 1);
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+
+ vk_buffer d_Qy = nullptr;
+ size_t qy_buf_offset = 0;
+
+ bool src1_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ src1_uma = d_Qy != nullptr;
+ }
+
+ const uint64_t d_ne = ne01 * ne11 * ne12;
+
+ const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
+ const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
+
+ const uint64_t qx_sz = ggml_nbytes(src0);
+ const uint64_t qy_sz = ggml_nbytes(src1);
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ vk_buffer d_Qx = extra_src0->buffer_gpu.lock();
+ const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ if (!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+        GGML_ASSERT(d_Qy != nullptr);
+ }
+
+ // Allocate descriptor sets
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+
+ const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+ const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
+
+ const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+ const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
+
+ // compute
+ const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+}
+
+static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
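+    // Dispatch: specialized f16 x f32 vector kernels for permuted (p021) or non-contiguous (nc) src0,
+    // the generic dequant mat-vec kernel for single-column results, and the full matmul path otherwise.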
+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
+ ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst);
+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
+ ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst);
+ } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+ ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst);
+ } else {
+ ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst);
+ }
+}
+
+static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ const uint64_t ne13 = src1->ne[3];
+
+ const uint64_t nei0 = ids->ne[0];
+ const uint64_t nei1 = ids->ne[1];
+ GGML_ASSERT(nei0 * nei1 <= 3072);
+
+ const uint32_t nbi1 = ids->nb[1];
+ const uint32_t nbi2 = ids->nb[2];
+
+ const uint64_t ne20 = dst->ne[0];
+ const uint64_t ne21 = dst->ne[1];
+ const uint64_t ne22 = dst->ne[2];
+ const uint64_t ne23 = dst->ne[3];
+
+ const uint64_t n_as = ne02;
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
+
+ vk_buffer d_Qx;
+ size_t qx_buf_offset = 0;
+ vk_buffer d_Qy;
+ size_t qy_buf_offset = 0;
+ vk_buffer d_ids;
+ size_t ids_buf_offset = 0;
+
+ bool src0_uma = false;
+ bool src1_uma = false;
+ bool ids_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
+ src0_uma = d_Qx != nullptr;
+ src1_uma = d_Qy != nullptr;
+ ids_uma = d_ids != nullptr;
+ }
+
+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+
+ const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
+
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+
+ const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+
+ if (mmp == nullptr) {
+ GGML_ASSERT(false);
+ }
+
+ // Not implemented
+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
+ const uint64_t x_ne = ne01 * ne00;
+ const uint64_t y_ne = ne11 * ne10;
+ const uint64_t d_ne = ne21 * ne20;
+
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
+ const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
+
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
+
+ const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
+ const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t ids_sz = nbi2;
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ vk_buffer d_X;
+ uint64_t x_buf_offset = 0;
+ vk_buffer d_Y;
+ uint64_t y_buf_offset = 0;
+ if (!src0_uma) {
+ d_Qx = extra_src0->buffer_gpu.lock();
+ qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ }
+ if (!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+ GGML_ASSERT(d_Qy != nullptr);
+ }
+ if (!ids_uma) {
+ d_ids = extra_ids->buffer_gpu.lock();
+ ids_buf_offset = extra_ids->offset + ids->view_offs;
+ GGML_ASSERT(d_ids != nullptr);
+ }
+ if (qx_needs_dequant) {
+ d_X = ctx->prealloc_x;
+ GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03);
+ } else {
+ d_X = d_Qx;
+ x_buf_offset = qx_buf_offset;
+ GGML_ASSERT(qx_sz == x_sz);
+ }
+ if (qy_needs_dequant) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ } else {
+ d_Y = d_Qy;
+ y_buf_offset = qy_buf_offset;
+ GGML_ASSERT(qy_sz == y_sz);
+ }
+
+ vk_pipeline to_fp16_vk_0 = nullptr;
+ vk_pipeline to_fp16_vk_1 = nullptr;
+
+ if (x_non_contig) {
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
+ } else {
+ to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
+ }
+ if (y_non_contig) {
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
+ } else {
+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+ }
+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+
+ // Allocate descriptor sets
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
+ if (qx_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+ }
+ if (qy_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+ }
+
+ if (x_non_contig) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ } else if (qx_needs_dequant) {
+ const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+ }
+ if (y_non_contig) {
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ }
+
+ uint32_t stride_batch_x = ne00*ne01;
+ uint32_t stride_batch_y = ne10*ne11;
+
+ if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
+ stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
+ }
+
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+ }
+
+ // compute
+ ggml_vk_matmul_id(
+ ctx, subctx, pipeline,
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
+ ne01, ne21, ne10, ne10, ne10, ne01,
+ stride_batch_x, stride_batch_y, ne20*ne21,
+ n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+ ); // NOLINT
+}
+
+static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT
+ GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t ne10 = src1->ne[0];
+ const uint64_t ne11 = src1->ne[1];
+ const uint64_t ne12 = src1->ne[2];
+ const uint64_t ne13 = src1->ne[3];
+
+ const uint64_t nei0 = ids->ne[0];
+ const uint64_t nei1 = ids->ne[1];
+
+ const uint64_t nbi2 = ids->nb[2];
+
+ GGML_ASSERT(nei1 == 1);
+
+ const uint64_t ne20 = dst->ne[0];
+ const uint64_t ne21 = dst->ne[1];
+ const uint64_t ne22 = dst->ne[2];
+ const uint64_t ne23 = dst->ne[3];
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra;
+ ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra;
+
+ vk_buffer d_Qx;
+ size_t qx_buf_offset = 0;
+ vk_buffer d_Qy;
+ size_t qy_buf_offset = 0;
+ vk_buffer d_ids;
+ size_t ids_buf_offset = 0;
+
+ bool src0_uma = false;
+ bool src1_uma = false;
+ bool ids_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
+ ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+ ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
+ src0_uma = d_Qx != nullptr;
+ src1_uma = d_Qy != nullptr;
+ ids_uma = d_ids != nullptr;
+ }
+
+ const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
+ const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+
+ const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
+
+ const bool qx_needs_dequant = x_non_contig;
+ const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+
+ // Not implemented
+ GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
+
+ const uint64_t x_ne = ne01 * ne00;
+ const uint64_t y_ne = ne11 * ne10;
+ const uint64_t d_ne = ne21 * ne20;
+
+ const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+ const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
+ const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+ const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t ids_sz = nbi2;
+ const uint64_t d_sz = sizeof(float) * d_ne;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+ const uint64_t d_buf_offset = extra->offset + dst->view_offs;
+ GGML_ASSERT(d_D != nullptr);
+ vk_buffer d_X;
+ uint64_t x_buf_offset = 0;
+ vk_buffer d_Y;
+ uint64_t y_buf_offset = 0;
+ if(!src0_uma) {
+ d_Qx = extra_src0->buffer_gpu.lock();
+ qx_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_Qx != nullptr);
+ }
+ if(!src1_uma) {
+ d_Qy = extra_src1->buffer_gpu.lock();
+ qy_buf_offset = extra_src1->offset + src1->view_offs;
+ GGML_ASSERT(d_Qy != nullptr);
+ }
+ if(!ids_uma) {
+ d_ids = extra_ids->buffer_gpu.lock();
+ ids_buf_offset = extra_ids->offset + ids->view_offs;
+ GGML_ASSERT(d_ids != nullptr);
+ }
+ if (qx_needs_dequant) {
+ d_X = ctx->prealloc_x;
+ } else {
+ d_X = d_Qx;
+ x_buf_offset = qx_buf_offset;
+ GGML_ASSERT(qx_sz == x_sz);
+ }
+ if (qy_needs_dequant) {
+ d_Y = ctx->prealloc_y;
+ } else {
+ d_Y = d_Qy;
+ y_buf_offset = qy_buf_offset;
+ GGML_ASSERT(qy_sz == y_sz);
+ }
+
+ vk_pipeline to_fp16_vk_0 = nullptr;
+ vk_pipeline to_fp16_vk_1 = nullptr;
+ if (x_non_contig) {
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
+ }
+ if (y_non_contig) {
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
+ } else {
+ to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
+ }
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
+ GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
+ GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+ GGML_ASSERT(dmmv != nullptr);
+
+ // Allocate descriptor sets
+ if (qx_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+ }
+ if (qy_needs_dequant) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+ }
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, dmmv, ne12 * ne13);
+
+ if (x_non_contig) {
+ GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ }
+ if (y_non_contig) {
+ GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ }
+
+ uint32_t stride_batch_y = ne10*ne11;
+
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+ }
+
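+ // if there are more rows than the device allows for workgroups in X, split the dispatch across Z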
+ const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
+
+ uint32_t groups_x = ne01;
+ uint32_t groups_z = 1;
+
+ if (ne01 > max_groups_x) {
+ groups_z = 64;
+ groups_x /= groups_z;
+ }
+
+ // compute
+ const vk_mat_vec_id_push_constants pc = {
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+ (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
+ (uint32_t)nei0, (uint32_t)ne11,
+ };
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+ { { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne22 * ne23}, { d_ids, ids_buf_offset, ids_sz } },
+ sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z });
+}
+
+static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
+ if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+ ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst);
+ } else {
+ ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
+ }
+}
+
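+// GGML_OP_REPEAT is handled with plain buffer copies instead of a compute shader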
+static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const uint64_t ne0 = dst->ne[0];
+ const uint64_t ne1 = dst->ne[1];
+ const uint64_t ne2 = dst->ne[2];
+ const uint64_t ne3 = dst->ne[3];
+
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+
+ const uint64_t nb0 = dst->nb[0];
+ const uint64_t nb1 = dst->nb[1];
+ const uint64_t nb2 = dst->nb[2];
+ const uint64_t nb3 = dst->nb[3];
+
+ const uint64_t nb00 = src0->nb[0];
+ const uint64_t nb01 = src0->nb[1];
+ const uint64_t nb02 = src0->nb[2];
+ const uint64_t nb03 = src0->nb[3];
+
+ const uint64_t nr0 = ne0/ne00;
+ const uint64_t nr1 = ne1/ne01;
+ const uint64_t nr2 = ne2/ne02;
+ const uint64_t nr3 = ne3/ne03;
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+
+ const vk_buffer src_buf = extra_src0->buffer_gpu.lock();
+ const uint64_t src_offset = extra_src0->offset + src0->view_offs;
+ vk_buffer dst_buf = extra->buffer_gpu.lock();
+ const uint64_t dst_offset = extra->offset + dst->view_offs;
+
+ std::vector<vk::BufferCopy> copies;
+
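+ // emit one copy per repeated position: every contiguous src0 row is duplicated into each tile of dst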
+ for (uint64_t i3 = 0; i3 < nr3; i3++) {
+ for (uint64_t k3 = 0; k3 < ne03; k3++) {
+ for (uint64_t i2 = 0; i2 < nr2; i2++) {
+ for (uint64_t k2 = 0; k2 < ne02; k2++) {
+ for (uint64_t i1 = 0; i1 < nr1; i1++) {
+ for (uint64_t k1 = 0; k1 < ne01; k1++) {
+ for (uint64_t i0 = 0; i0 < nr0; i0++) {
+ copies.push_back({
+ src_offset + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01,
+ dst_offset + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0,
+ ne00*nb0,
+ });
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ subctx->s->buffer.copyBuffer(src_buf->buffer, dst_buf->buffer, copies);
+
+ GGML_UNUSED(ctx);
+ GGML_UNUSED(src1);
+}
+
+
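+// Map a ggml op (and the tensor types involved) to the matching compiled pipeline; nullptr means the op is not handled by a shader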
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+ switch (op) {
+ case GGML_OP_ADD:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_add_f32;
+ }
+ return nullptr;
+ case GGML_OP_GET_ROWS:
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ if (dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_get_rows[src0->type];
+ }
+ if (dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_get_rows_f32[src0->type];
+ }
+ return nullptr;
+ case GGML_OP_MUL:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_mul_f32;
+ }
+ return nullptr;
+ case GGML_OP_DIV:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_div_f32;
+ }
+ return nullptr;
+ case GGML_OP_SCALE:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_scale_f32;
+ }
+ return nullptr;
+ case GGML_OP_SQR:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sqr_f32;
+ }
+ return nullptr;
+ case GGML_OP_CLAMP:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_clamp_f32;
+ }
+ return nullptr;
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DUP:
+ return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
+ case GGML_OP_NORM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_norm_f32;
+ }
+ return nullptr;
+ case GGML_OP_RMS_NORM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rms_norm_f32;
+ }
+ return nullptr;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_SILU:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_silu_f32;
+ }
+ break;
+ case GGML_UNARY_OP_GELU:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_gelu_f32;
+ }
+ break;
+ case GGML_UNARY_OP_RELU:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_relu_f32;
+ }
+ break;
+ default:
+ break;
+ }
+ return nullptr;
+ case GGML_OP_DIAG_MASK_INF:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_diag_mask_inf_f32;
+ }
+ return nullptr;
+ case GGML_OP_SOFT_MAX:
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_soft_max_f32;
+ }
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_soft_max_f32_f16;
+ }
+ return nullptr;
+ case GGML_OP_ROPE:
+ {
+ const int mode = ((const int32_t *) dst->op_params)[2];
+ const bool is_neox = mode & 2;
+
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rope_neox_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_neox_f16;
+ }
+ } else {
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rope_norm_f32;
+ }
+ if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ return ctx->device->pipeline_rope_norm_f16;
+ }
+ }
+ return nullptr;
+ }
+ case GGML_OP_ARGSORT:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
+ return ctx->device->pipeline_argsort_f32;
+ }
+ return nullptr;
+ case GGML_OP_SUM_ROWS:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sum_rows_f32;
+ }
+ return nullptr;
+ default:
+ return nullptr;
+ }
+
+ GGML_UNUSED(src2);
+}
+
+static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
+ switch (op) {
+ case GGML_OP_REPEAT:
+ return ggml_vk_op_repeat;
+ default:
+ return nullptr;
+ }
+}
+
+static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
+ switch (op) {
+ case GGML_OP_CPY:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ return true;
+ default:
+ return false;
+ }
+}
+
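+// Generic dispatcher for element-wise / row-wise ops: selects the pipeline, binds the source and destination buffers and dispatches with the given push constants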
+template<typename PC>
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+ VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+ if (src1 != nullptr) {
+ std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+ }
+ if (src2 != nullptr) {
+ std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
+ }
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")");
+ GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
+ GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
+ GGML_ASSERT(dst->extra != nullptr);
+ const uint64_t ne00 = src0->ne[0];
+ const uint64_t ne01 = src0->ne[1];
+ const uint64_t ne02 = src0->ne[2];
+ const uint64_t ne03 = src0->ne[3];
+ const uint64_t ne0 = ne00 * ne01;
+
+ const bool use_src1 = src1 != nullptr;
+ const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
+ const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
+ const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
+ const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
+ const uint64_t ne1 = ne10 * ne11;
+ // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;
+
+ const bool use_src2 = src2 != nullptr;
+ const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
+ const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
+ const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
+ const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
+ const uint64_t ne2 = ne20 * ne21;
+
+ const uint64_t ned0 = dst->ne[0];
+ const uint64_t ned1 = dst->ne[1];
+ const uint64_t ned2 = dst->ne[2];
+ const uint64_t ned3 = dst->ne[3];
+ const uint64_t ned = ned0 * ned1;
+
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
+ ggml_vk_func_t op_func;
+
+ if (pipeline == nullptr) {
+ op_func = ggml_vk_op_get_func(op);
+ if (op_func == nullptr) {
+ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
+ if (src1 != nullptr) {
+ std::cerr << " and " << ggml_type_name(src1->type);
+ }
+ std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
+ GGML_ASSERT(false);
+ }
+
+ op_func(ctx, subctx, src0, src1, dst);
+ return;
+ }
+
+ const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+ ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
+
+ vk_buffer d_X = nullptr;
+ size_t x_buf_offset = 0;
+ vk_buffer d_Y = nullptr;
+ size_t y_buf_offset = 0;
+ vk_buffer d_Z = nullptr;
+ size_t z_buf_offset = 0;
+
+ bool src0_uma = false;
+ bool src1_uma = false;
+ bool src2_uma = false;
+
+ if (ctx->device->uma) {
+ ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
+ src0_uma = d_X != nullptr;
+ if (use_src1) {
+ ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
+ src1_uma = d_Y != nullptr;
+ }
+ if (use_src2) {
+ ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
+ src2_uma = d_Z != nullptr;
+ }
+ }
+
+ uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
+ uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
+ uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
+ uint64_t d_sz = ggml_type_size(dst->type) * ned;
+
+ vk_buffer d_D = extra->buffer_gpu.lock();
+
+ // Workaround for tiny tensor inputs on ROPE
+ if (use_src1 && y_sz > d_D->size) {
+ y_sz = VK_WHOLE_SIZE;
+ }
+
+ GGML_ASSERT(d_D != nullptr);
+ uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+ GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT
+ if (!src0_uma) {
+ d_X = extra_src0->buffer_gpu.lock();
+ x_buf_offset = extra_src0->offset + src0->view_offs;
+ GGML_ASSERT(d_X != nullptr);
+ }
+ if (use_src1 && !src1_uma) {
+ d_Y = extra_src1->buffer_gpu.lock();
+ y_buf_offset = extra_src1->offset + src1->view_offs;
+ GGML_ASSERT(d_Y != nullptr);
+ }
+ if (use_src2 && !src2_uma) {
+ d_Z = extra_src2->buffer_gpu.lock();
+ z_buf_offset = extra_src2->offset + src2->view_offs;
+ GGML_ASSERT(d_Z != nullptr);
+ }
+
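+ // ops that handle non-contiguous tensors bind the full tensor size in bytes, clamped to VK_WHOLE_SIZE when the range would reach the end of the buffer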
+ if (op_supports_incontiguous) {
+ x_sz = ggml_nbytes(src0);
+ y_sz = use_src1 ? ggml_nbytes(src1) : 0;
+ z_sz = use_src2 ? ggml_nbytes(src2) : 0;
+ d_sz = ggml_nbytes(dst);
+
+ if (x_buf_offset + x_sz >= d_X->size) {
+ x_sz = VK_WHOLE_SIZE;
+ }
+ if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
+ y_sz = VK_WHOLE_SIZE;
+ }
+ if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
+ z_sz = VK_WHOLE_SIZE;
+ }
+ if (d_buf_offset + d_sz >= d_D->size) {
+ d_sz = VK_WHOLE_SIZE;
+ }
+ }
+
+ std::array<uint32_t, 3> elements;
+
+ // Single call if dimension 2 is contiguous
+ if (op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, 1);
+
+ switch (dst->op) {
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_SUM_ROWS:
+ elements = { (uint32_t)ggml_nrows(src0), 1, 1 };
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_ROPE:
+ elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
+ break;
+ case GGML_OP_GET_ROWS:
+ elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
+ break;
+ case GGML_OP_ARGSORT:
+ elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+ break;
+ default:
+ elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
+ break;
+ }
+
+ if (!op_supports_incontiguous) {
+ if (x_sz != VK_WHOLE_SIZE) {
+ x_sz *= ne02 * ne03;
+ }
+ if (use_src1 && y_sz != VK_WHOLE_SIZE) {
+ y_sz *= ne12 * ne13;
+ }
+ if (use_src2 && z_sz != VK_WHOLE_SIZE) {
+ z_sz *= ne22 * ne23;
+ }
+ if (d_sz != VK_WHOLE_SIZE) {
+ d_sz *= ned2 * ned3;
+ }
+ }
+
+ if (op == GGML_OP_SOFT_MAX) {
+ // Empty src1 is possible in soft_max, but the shader needs a buffer
+ vk_subbuffer subbuf_y;
+ if (use_src1) {
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
+ } else {
+ subbuf_y = { d_X, 0, d_X->size };
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else if (op == GGML_OP_ROPE) {
+ // Empty src2 is possible in rope, but the shader needs a buffer
+ vk_subbuffer subbuf_z;
+ if (use_src2) {
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
+ } else {
+ subbuf_z = { d_X, 0, d_X->size };
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else if (use_src2) {
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else if (use_src1) {
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else {
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+ }
+ } else {
+ GGML_ASSERT(op != GGML_OP_SOFT_MAX);
+ GGML_ASSERT(op != GGML_OP_ARGSORT);
+ GGML_ASSERT(!use_src2);
+
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, pipeline, ne02 * ne03);
+
+ switch (dst->op) {
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ elements = { (uint32_t)ne01, 1, 1 };
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_ROPE:
+ elements = { (uint32_t)ne01, (uint32_t)ne00, 1 };
+ break;
+ case GGML_OP_GET_ROWS:
+ elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
+ break;
+ default:
+ elements = { (uint32_t)ne0, 1, 1 };
+ break;
+ }
+
+ for (uint64_t i03 = 0; i03 < ne03; i03++) {
+ for (uint64_t i02 = 0; i02 < ne02; i02++) {
+ const uint32_t it_idx0 = (i03 * ne02 + i02);
+ const uint32_t it_idx1 = use_src1 ? ((i03 % ne13) * ne12 + (i02 % ne12)) : 0;
+ const uint32_t x_offset = x_sz * it_idx0;
+ const uint32_t y_offset = y_sz * it_idx1;
+ const uint32_t d_offset = d_sz * it_idx0;
+
+ if (use_src1) {
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
+ } else {
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
+ }
+ }
+ }
+ }
+}
+
+static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+}
+
+static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ op_params[0], 0.0f
+ });
+}
+
+static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ op_params[0], op_params[1],
+ });
+}
+
+static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+ const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ d_offset,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+}
+
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+}
+
+static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+}
+
+static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ int32_t * op_params = (int32_t *)dst->op_params;
+ ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
+}
+
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ float * op_params = (float *)dst->op_params;
+
+ float scale = op_params[0];
+ float max_bias = op_params[1];
+
+ const uint32_t ncols = (uint32_t)src0->ne[0];
+ const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
+ const uint32_t nrows_y = (uint32_t)src0->ne[1];
+
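+ // ALiBi slope bases m0/m1 derived from max_bias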
+ const uint32_t n_head_kv = nrows_x/nrows_y;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
+ ncols,
+ src1 != nullptr ? nrows_y : (uint32_t)0,
+ scale, max_bias,
+ m0, m1,
+ n_head_log2,
+ });
+}
+
+static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ // const int mode = ((int32_t *) dst->op_params)[2];
+ // const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+ const float freq_base = ((float *) dst->op_params)[5];
+ const float freq_scale = ((float *) dst->op_params)[6];
+ const float ext_factor = ((float *) dst->op_params)[7];
+ const float attn_factor = ((float *) dst->op_params)[8];
+ const float beta_fast = ((float *) dst->op_params)[9];
+ const float beta_slow = ((float *) dst->op_params)[10];
+
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
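+ // per-dimension rotation frequency: theta_i = theta_scale^i = freq_base^(-2*i/n_dims)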
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
+ (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+ freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
+ src2 != nullptr,
+ });
+}
+
+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ int32_t * op_params = (int32_t *)dst->op_params;
+
+ uint32_t ncols = src0->ne[0];
+
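+ // the sort shader operates on power-of-two row lengths, up to 1024 columns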
+ uint32_t ncols_pad = 1;
+ while (ncols_pad < ncols) {
+ ncols_pad *= 2;
+ }
+
+ GGML_ASSERT(ncols_pad <= 1024);
+
+ ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
+ ncols,
+ ncols_pad,
+ op_params[0],
+ });
+}
+
+static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f });
+}
+
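+// Self-tests, built only with GGML_VULKAN_RUN_TESTS: they run matmul/dequant on the GPU and compare the results against the CPU backend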
+#ifdef GGML_VULKAN_RUN_TESTS
+static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
+ if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
+ return;
+ }
+ i0 = std::max(i0, 5);
+ i1 = std::max(i1, 5);
+ i2 = std::max(i2, 0);
+ fprintf(stderr, " ");
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ fprintf(stderr, "%7d ", idx1);
+ }
+ fprintf(stderr, "\n");
+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
+ fprintf(stderr, "%7d: ", idx0);
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
+ float val;
+ if (type == GGML_TYPE_F32) {
+ val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
+ } else if (type == GGML_TYPE_F16) {
+ val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
+ } else {
+ GGML_ASSERT(false);
+ }
+ fprintf(stderr, "% 7.2f ", val);
+ } else {
+ fprintf(stderr, " ");
+ }
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+template <typename X_TYPE, typename Y_TYPE>
+static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
+ VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
+ const size_t x_ne = m * k * batch;
+ const size_t y_ne = k * n * batch;
+ const size_t d_ne = m * n * batch;
+
+ vk_pipeline p;
+ std::string shname;
+ if (shader_size == 0) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->a_s;
+ shname = "F32_ALIGNED_S";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->a_s;
+ shname = "F32_F16_ALIGNED_S";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->a_s;
+ shname = "F16_F32_ALIGNED_S";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->a_s;
+ shname = "F16_ALIGNED_S";
+ } else {
+ GGML_ASSERT(false);
+ }
+ } else if (shader_size == 1) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->a_m;
+ shname = "F32_ALIGNED_M";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->a_m;
+ shname = "F32_F16_ALIGNED_M";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->a_m;
+ shname = "F16_F32_ALIGNED_M";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->a_m;
+ shname = "F16_ALIGNED_M";
+ } else {
+ GGML_ASSERT(false);
+ }
+ } else if (shader_size == 2) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->a_l;
+ shname = "F32_ALIGNED_L";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->a_l;
+ shname = "F32_F16_ALIGNED_L";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->a_l;
+ shname = "F16_F32_ALIGNED_L";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->a_l;
+ shname = "F16_ALIGNED_L";
+ } else {
+ GGML_ASSERT(false);
+ }
+ } else {
+ GGML_ASSERT(0);
+ }
+
+ const size_t kpad = ggml_vk_align_size(k, p->align);
+
+ if (k != kpad) {
+ if (shader_size == 0) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->s;
+ shname = "F32_S";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->s;
+ shname = "F32_F16_S";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->s;
+ shname = "F16_F32_S";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->s;
+ shname = "F16_S";
+ }
+ } else if (shader_size == 1) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->m;
+ shname = "F32_M";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->m;
+ shname = "F32_F16_M";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->m;
+ shname = "F16_F32_M";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->m;
+ shname = "F16_M";
+ }
+ } else if (shader_size == 2) {
+ if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32->l;
+ shname = "F32_L";
+ } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f32_f16->l;
+ shname = "F32_F16_L";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16_f32->l;
+ shname = "F16_F32_L";
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ p = ctx->device->pipeline_matmul_f16->l;
+ shname = "F16_L";
+ }
+ }
+ }
+
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it);
+ if (split_k > 1) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+
+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
+ // Resize buffer
+ if (ctx->prealloc_split_k != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+ }
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ }
+ }
+
+ vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
+
+ X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
+ Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
+ float* d = (float *) malloc(sizeof(float) * d_ne);
+
+ for (size_t i = 0; i < x_ne; i++) {
+ if (std::is_same<float, X_TYPE>()) {
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
+ x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+ for (size_t i = 0; i < y_ne; i++) {
+ if (std::is_same<float, Y_TYPE>()) {
+ // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+ } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+ y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
+ } else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
+ ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
+
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ for (size_t i = 0; i < num_it; i++) {
+ ggml_vk_ctx_begin(ctx->device, subctx);
+ ggml_vk_matmul(
+ ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
+ m, n, k,
+ k, k, m, k*m, k*n, m*n,
+ split_k, batch, batch, batch, 1, 1
+ );
+ ggml_vk_ctx_end(subctx);
+ }
+
+ auto begin = std::chrono::high_resolution_clock::now();
+ ggml_vk_submit(subctx, ctx->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
+ ctx->device->device.resetFences({ ctx->fence });
+
+ auto end = std::chrono::high_resolution_clock::now();
+ double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+
+ // copy dst to host
+ ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
+
+ float * d_chk = (float *) malloc(sizeof(float) * d_ne);
+
+ ggml_init_params iparams = {
+ /*.mem_size =*/ 1024*1024*1024,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ggml_ctx = ggml_init(iparams);
+
+ ggml_type src0_type;
+ ggml_type src1_type;
+
+ if (std::is_same<float, X_TYPE>()) {
+ src0_type = GGML_TYPE_F32;
+ } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
+ src0_type = GGML_TYPE_F16;
+ } else {
+ GGML_ASSERT(false);
+ }
+ if (std::is_same<float, Y_TYPE>()) {
+ src1_type = GGML_TYPE_F32;
+ } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
+ src1_type = GGML_TYPE_F16;
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
+
+ src0_ggml->data = x;
+ src1_ggml->data = y;
+ tensor_ggml->data = d_chk;
+
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
+ ggml_build_forward_expand(cgraph, tensor_ggml);
+
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
+
+ ggml_free(ggml_ctx);
+
+ double avg_err = 0.0;
+ int first_err_n = -1;
+ int first_err_m = -1;
+ int first_err_b = -1;
+
+ for (size_t i = 0; i < m*n*batch; i++) {
+ double err = std::fabs(d[i] - d_chk[i]);
+ avg_err += err;
+
+ if (err > 0.05f && first_err_n == -1) {
+ first_err_b = i / (m * n);
+ first_err_n = (i % (m * n)) / m;
+ first_err_m = (i % (m * n)) % m;
+ }
+ }
+
+ avg_err /= m * n;
+
+ std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms avg_err=" << avg_err << std::endl;
+
+ if (avg_err > 0.1) {
+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
+ std::cerr << "Actual result: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+ std::cerr << std::endl;
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
+ std::cerr << "Expected result: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ if (split_k > 1) {
+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
+ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
+
+ std::cerr << "d_buf0: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf1: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf2: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf3: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ free(split_k_buf);
+ }
+ }
+
+ free(d_chk);
+
+ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
+ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
+
+ ggml_vk_destroy_buffer(d_X);
+ ggml_vk_destroy_buffer(d_Y);
+ ggml_vk_destroy_buffer(d_D);
+
+ ggml_pipeline_cleanup(p);
+ ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
+
+ free(x);
+ free(y);
+ free(d);
+}
+
+static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
+ return;
+ }
+ i0 = std::max(i0, 5);
+ i1 = std::max(i1, 5);
+ i2 = std::max(i2, 0);
+ i3 = std::max(i3, 0);
+ fprintf(stderr, " ");
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ fprintf(stderr, "%7d ", idx1);
+ }
+ fprintf(stderr, "\n");
+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
+ fprintf(stderr, "%7d: ", idx0);
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
+ float val;
+ if (tensor->type == GGML_TYPE_F32) {
+ val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
+ } else if (tensor->type == GGML_TYPE_F16) {
+ val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
+ } else {
+ GGML_ASSERT(false);
+ }
+ fprintf(stderr, "% 7.2f ", val);
+ } else {
+ fprintf(stderr, " ");
+ }
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
+ ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
+}
+
+static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
+ if (quant == GGML_TYPE_F32) {
+ memcpy(to, from, sizeof(float) * ne);
+ return;
+ }
+
+ ggml_type_traits_t tt = ggml_internal_get_type_traits(quant);
+
+ ggml_to_float_t dequant_fn = tt.to_float;
+
+ dequant_fn(from, to, ne);
+}
+
+static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
+ VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
+ const size_t x_sz = sizeof(float) * ne;
+ const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
+ const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
+ float * x = (float *) malloc(x_sz);
+ void * qx = malloc(qx_sz);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ float * x_ref = (float *) malloc(x_sz);
+ ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
+
+ for (size_t i = 0; i < ne; i++) {
+ x[i] = rand() / (float)RAND_MAX;
+ }
+
+ vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
+
+ ggml_vk_quantize_data(x, qx, ne, quant);
+ ggml_vk_dequantize_data(qx, x_ref, ne, quant);
+
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, p, 1);
+
+ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
+
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ggml_vk_ctx_begin(ctx->device, subctx);
+ const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
+ ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
+ ggml_vk_ctx_end(subctx);
+
+ auto begin = std::chrono::high_resolution_clock::now();
+
+ ggml_vk_submit(subctx, ctx->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
+ ctx->device->device.resetFences({ ctx->fence });
+
+ auto end = std::chrono::high_resolution_clock::now();
+
+ double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+ ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
+
+ int first_err = -1;
+
+ double avg_err = 0.0;
+ for (size_t i = 0; i < ne; i++) {
+ double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
+ avg_err += error;
+
+ if (first_err < 0 && error > 0.05) {
+ first_err = i;
+ }
+ }
+
+ avg_err /= ne;
+
+ std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
+
+ if (avg_err > 0.1) {
+ std::cerr << "first_error = " << first_err << std::endl;
+ std::cerr << "Actual result: " << std::endl << std::endl;
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
+ std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
+ }
+ std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
+ for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
+ std::cerr << x_ref[i] << ", ";
+ }
+ std::cerr << std::endl;
+ }
+
+ ggml_vk_destroy_buffer(x_buf);
+ ggml_vk_destroy_buffer(qx_buf);
+
+ free(x);
+ free(qx);
+ free(x_ref);
+ free(x_chk);
+}
+
+static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
+ VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
+ const size_t x_ne = m * k * batch;
+ const size_t y_ne = k * n * batch;
+ const size_t d_ne = m * n * batch;
+
+ vk_pipeline p;
+ std::string shname;
+ if (shader_size == 0) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
+ } else if (shader_size == 1) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
+ } else if (shader_size == 2) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
+ shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
+ } else {
+ GGML_ASSERT(0);
+ }
+
+ const size_t kpad = ggml_vk_align_size(k, p->align);
+
+ if (k != kpad) {
+ if (shader_size == 0) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
+ shname = std::string(ggml_type_name(quant)) + "_S";
+ } else if (shader_size == 1) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
+ shname = std::string(ggml_type_name(quant)) + "_M";
+ } else if (shader_size == 2) {
+ p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
+ shname = std::string(ggml_type_name(quant)) + "_L";
+ } else {
+ GGML_ASSERT(0);
+ }
+ }
+
+ const size_t x_sz = sizeof(float) * x_ne;
+ const size_t y_sz = sizeof(float) * y_ne;
+ const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
+ const size_t d_sz = sizeof(float) * d_ne;
+ float * x = (float *) malloc(x_sz);
+ float * y = (float *) malloc(y_sz);
+ void * qx = malloc(qx_sz);
+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ float * d = (float *) malloc(d_sz);
+ float * d_chk = (float *) malloc(d_sz);
+
+ for (size_t i = 0; i < x_ne; i++) {
+ x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+ }
+
+ ggml_vk_quantize_data(x, qx, x_ne, quant);
+
+ for (size_t i = 0; i < y_ne; i++) {
+ // y[i] = rand() / (float)RAND_MAX;
+ y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+ }
+
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, p, num_it);
+ if (split_k > 1) {
+ ggml_pipeline_allocate_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+
+ if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
+ // Resize buffer
+ if (ctx->prealloc_split_k != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+ }
+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+ }
+ }
+
+ ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
+ ggml_vk_buffer_write(y_buf, 0, y, y_sz);
+
+ vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ for (size_t i = 0; i < num_it; i++) {
+ ggml_vk_ctx_begin(ctx->device, subctx);
+ ggml_vk_matmul(
+ ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
+ m, n, k,
+ k, k, m, k*m, k*n, m*n,
+ split_k, batch, batch, batch, 1, 1
+ );
+ ggml_vk_ctx_end(subctx);
+ }
+
+ auto begin = std::chrono::high_resolution_clock::now();
+
+ ggml_vk_submit(subctx, ctx->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
+ ctx->device->device.resetFences({ ctx->fence });
+
+ auto end = std::chrono::high_resolution_clock::now();
+
+ double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+ ggml_vk_buffer_read(d_buf, 0, d, d_sz);
+
+ ggml_init_params iparams = {
+ /*.mem_size =*/ 1024*1024*1024,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context * ggml_ctx = ggml_init(iparams);
+
+ ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
+ ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
+ ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
+
+ src0_ggml->data = qx;
+ src1_ggml->data = y;
+ tensor_ggml->data = d_chk;
+
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
+ ggml_build_forward_expand(cgraph, tensor_ggml);
+
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
+
+ ggml_free(ggml_ctx);
+
+ double avg_err = 0.0;
+ int first_err_n = -1;
+ int first_err_m = -1;
+ int first_err_b = -1;
+
+ for (size_t i = 0; i < m*n*batch; i++) {
+ double err = std::fabs(d[i] - d_chk[i]);
+ avg_err += err;
+
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
+ first_err_b = i / (m * n);
+ first_err_n = (i % (m * n)) / m;
+ first_err_m = (i % (m * n)) % m;
+ }
+ }
+
+ avg_err /= m * n;
+
+ std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
+
+ if (avg_err > 0.01 || std::isnan(avg_err)) {
+ std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
+ std::cerr << "Actual result: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+ std::cerr << std::endl;
+ std::cerr << "Expected result: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ if (split_k > 1) {
+ float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
+ ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
+
+ std::cerr << "d_buf0: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf1: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf2: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ std::cerr << "d_buf3: " << std::endl << std::endl;
+ ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+ free(split_k_buf);
+ }
+ }
+
+ ggml_vk_destroy_buffer(qx_buf);
+ ggml_vk_destroy_buffer(y_buf);
+ ggml_vk_destroy_buffer(d_buf);
+
+ free(x);
+ free(qx);
+ free(y);
+ free(d);
+ free(d_chk);
+}
+#endif
+
+static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
+ VK_LOG_DEBUG("ggml_vk_create_extra(" << tensor << " (" << tensor->name << ", " << ggml_op_name(tensor->op) << "))");
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+ extra->reset();
+ tensor->extra = extra;
+ return extra;
+}
+
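+// Inspect a graph node and grow the preallocated dequant/staging buffer sizes so the largest mul_mat(_id) in the graph fits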
+static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
+ VK_LOG_DEBUG("ggml_vk_preallocate_buffers_graph(" << node << ")");
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
+
+ if (extra == nullptr) {
+ return;
+ }
+
+ ggml_tensor * src0 = node->src[0];
+ ggml_tensor * src1 = node->src[1];
+
+ const bool use_src0 = src0 != nullptr;
+ const int64_t ne00 = use_src0 ? src0->ne[0] : 0;
+ const int64_t ne01 = use_src0 ? src0->ne[1] : 0;
+ const int64_t ne02 = use_src0 ? src0->ne[2] : 0;
+ const int64_t ne03 = use_src0 ? src0->ne[3] : 0;
+ const bool use_src1 = src1 != nullptr && node->op != GGML_OP_CPY && node->op != GGML_OP_CONT && node->op != GGML_OP_DUP;
+ const int64_t ne10 = use_src1 ? src1->ne[0] : 0;
+ const int64_t ne11 = use_src1 ? src1->ne[1] : 0;
+ const int64_t ne12 = use_src1 ? src1->ne[2] : 0;
+ const int64_t ne13 = use_src1 ? src1->ne[3] : 0;
+ const int64_t ne20 = node->ne[0];
+ const int64_t ne21 = node->ne[1];
+ const int64_t ne22 = node->ne[2];
+ const int64_t ne23 = node->ne[3];
+
+ const ggml_type src0_type = (use_src0 && src0->type == GGML_TYPE_F32) ? src0->type : GGML_TYPE_F16;
+ const ggml_type src1_type = (use_src1 && src1->type == GGML_TYPE_F32) ? src1->type : GGML_TYPE_F16;
+
+ const bool x_non_contig = use_src0 && !ggml_vk_dim01_contiguous(src0);
+ const bool y_non_contig = use_src1 && !ggml_vk_dim01_contiguous(src1);
+
+ const bool y_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32 && !y_non_contig;
+
+ bool mmp = (use_src0 && use_src1 && src1_type == GGML_TYPE_F32) ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0_type, y_non_contig ? GGML_TYPE_F16 : src1->type) != nullptr : false;
+
+ const bool qx_needs_dequant = use_src0 && (!mmp || x_non_contig);
+ const bool qy_needs_dequant = use_src1 && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
+
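+ // split_k splits the shared K dimension of the matmul across several dispatches;
+ // the partial results go to a separate accumulation buffer (prealloc_split_k,
+ // sized below as four output-sized slices) and are reduced afterwards.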
+ int split_k;
+ if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
+ split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
+ } else {
+ split_k = 1;
+ }
+ const uint32_t x_ne = ne00 * ne01;
+ const uint32_t y_ne = ne10 * ne11;
+ const uint32_t d_ne = ne20 * ne21;
+
+ const uint64_t x_sz = (use_src0 && qx_needs_dequant) ? ggml_vk_align_size(sizeof(src0_type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
+ const uint64_t y_sz = (use_src1 && qy_needs_dequant) ? ggml_vk_align_size(sizeof(src1_type) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
+ uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
+ const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
+
+ if (extra->buffer_gpu.expired()) {
+ // Workaround for CPU backend BLAS matmul calls
+ extra->buffer_gpu = ggml_vk_create_buffer_temp(ctx, d_sz);
+ }
+
+ switch (node->op) {
+ case GGML_OP_REPEAT:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_ADD:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DUP:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_SUM_ROWS:
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_RELU:
+ break;
+ default:
+ return;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ if (ctx->prealloc_size_x < x_sz) {
+ ctx->prealloc_size_x = x_sz;
+ }
+ if (ctx->prealloc_size_y < y_sz) {
+ ctx->prealloc_size_y = y_sz;
+ }
+ if (ctx->prealloc_size_split_k < split_k_size) {
+ ctx->prealloc_size_split_k = split_k_size;
+ }
+ if (ctx->staging_size < x_sz + y_sz) {
+ ctx->staging_size = x_sz + y_sz;
+ }
+ break;
+ default:
+ return;
+ }
+}
+
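+// Allocate (or grow) the shared device-local and staging buffers to the sizes
+// gathered by ggml_vk_preallocate_buffers_graph() for the current graph.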
+static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
+#if defined(GGML_VULKAN_RUN_TESTS)
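+ // Test path: allocate a host-visible staging buffer and run the dequantization
+ // and matmul correctness/performance tests below, then abort on purpose via the
+ // GGML_ASSERT(false) at the end of this block.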
+ ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
+ ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
+
+ ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
+
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
+ ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
+ // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
+ // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
+ // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
+ // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
+
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
+ ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
+
+ std::cerr << std::endl;
+
+ const std::vector<size_t> vals {
+ 8, 8, 8,
+ 100, 46, 576,
+ 623, 111, 128,
+ 100, 46, 558,
+ 512, 1, 256,
+ 128, 110, 622,
+ 511, 511, 127,
+ 511, 511, 7,
+ 511, 511, 17,
+ 49, 49, 128,
+ 128, 49, 49,
+ 4096, 49, 4096,
+ 11008, 49, 4096,
+ 4096, 49, 11008,
+ 32000, 49, 4096,
+ 512, 512, 128,
+ 128, 512, 512,
+ 4096, 512, 4096,
+ 11008, 512, 4096,
+ 4096, 512, 11008,
+ 32000, 512, 4096,
+ };
+ const size_t num_it = 1;
+ for (size_t i = 0; i < vals.size(); i += 3) {
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
+ // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
+ // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
+ // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
+ std::cerr << std::endl;
+ }
+
+ GGML_ASSERT(false);
+#endif
+
+ if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
+ // Resize buffer
+ if (ctx->prealloc_x != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_x);
+ }
+ ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
+ }
+ if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
+ // Resize buffer
+ if (ctx->prealloc_y != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_y);
+ }
+ ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
+ }
+ if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
+ // Resize buffer
+ if (ctx->prealloc_split_k != nullptr) {
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+ }
+ ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
+ }
+ if (ctx->staging == nullptr || (ctx->staging_size > 0 && ctx->staging->size < ctx->staging_size)) {
+ VK_LOG_MEMORY("ggml_vk_preallocate_buffers(staging_size: " << ctx->staging_size << ")");
+ // Resize buffer
+ if (ctx->staging != nullptr) {
+ ggml_vk_destroy_buffer(ctx->staging);
+ }
+ ctx->staging = ggml_vk_create_buffer_check(ctx->device, ctx->staging_size,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
+ }
+}
+
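+// Record one graph node into the shared compute context. The context is created
+// lazily on the first recorded node; when last_node is reached it is closed and
+// the node is stored as the context's exit_tensor for ggml_vk_compute_forward().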
+static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, bool last_node){
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
+
+ if (ggml_is_empty(node) || extra == nullptr) {
+ return;
+ }
+
+ VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
+ ctx->semaphore_idx = 0;
+ ctx->staging_offset = 0;
+
+ const ggml_tensor * src0 = node->src[0];
+ const ggml_tensor * src1 = node->src[1];
+ const ggml_tensor * src2 = node->src[2];
+
+ switch (node->op) {
+ // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NONE:
+ return;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_RELU:
+ break;
+ default:
+ return;
+ }
+ break;
+ case GGML_OP_REPEAT:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DUP:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_SUM_ROWS:
+ break;
+ default:
+ std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
+ GGML_ASSERT(false);
+ return;
+ }
+
+ if (ctx->compute_ctx == nullptr) {
+ ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+ ggml_vk_ctx_begin(ctx->device, ctx->compute_ctx);
+ }
+
+ switch (node->op) {
+ case GGML_OP_REPEAT:
+ ggml_vk_repeat(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_GET_ROWS:
+ ggml_vk_get_rows(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_ADD:
+ ggml_vk_add(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_MUL:
+ ggml_vk_mul(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_DIV:
+ ggml_vk_div(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_SCALE:
+ ggml_vk_scale(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_SQR:
+ ggml_vk_sqr(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_CLAMP:
+ ggml_vk_clamp(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DUP:
+ ggml_vk_cpy(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_NORM:
+ ggml_vk_norm(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_RMS_NORM:
+ ggml_vk_rms_norm(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_RELU:
+ ggml_vk_unary(ctx, ctx->compute_ctx, src0, node);
+ break;
+ default:
+ return;
+ }
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ ggml_vk_diag_mask_inf(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_SOFT_MAX:
+ ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_ROPE:
+ ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, src2, node);
+
+ break;
+ case GGML_OP_ARGSORT:
+ ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_vk_sum_rows(ctx, ctx->compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_MUL_MAT:
+ ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
+
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, src2, node);
+
+ break;
+ default:
+ return;
+ }
+
+ extra->ctx_idx = ctx->compute_ctx->idx;
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+ // Force context reset on each node so that each tensor ends up in its own context
+ // and can be run and compared to its CPU equivalent separately
+ last_node = true;
+#endif
+
+ if (last_node) {
+ ggml_vk_ctx_end(ctx->compute_ctx);
+ ctx->compute_ctx->exit_tensor = node;
+ ctx->compute_ctx = nullptr;
+ }
+}
+
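+// Execute a recorded node: submit its compute context the first time one of its
+// tensors is processed (staging-in copies happen just before submission), and on
+// the context's exit_tensor wait for the fence and perform the staging-out copies.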
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor){
+ ggml_tensor_extra_gpu * extra = nullptr;
+
+ switch (tensor->op) {
+ case GGML_OP_ADD:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CPY:
+ case GGML_OP_CONT:
+ case GGML_OP_DUP:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NONE:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_SUM_ROWS:
+ extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_RELU:
+ extra = (ggml_tensor_extra_gpu *) tensor->extra;
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ break;
+ default:
+ return false;
+ }
+
+ if (extra == nullptr) {
+ return false;
+ }
+
+ VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+ ggml_vk_check_results_0(ctx, tensor);
+#endif
+
+ vk_context& subctx = ctx->gc.contexts[extra->ctx_idx];
+
+ // Only run if ctx hasn't been submitted yet
+ if (!subctx.seqs.empty()) {
+ // Do staging buffer copies
+ for (auto& cpy : subctx.in_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+
+ ggml_vk_submit(&subctx, ctx->fence);
+ }
+
+ if (tensor == subctx.exit_tensor) {
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
+ ctx->device->device.resetFences({ ctx->fence });
+
+ // Do staging buffer copies
+ for (auto& cpy : subctx.out_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+ subctx.in_memcpys.clear();
+ subctx.out_memcpys.clear();
+ }
+
+ return true;
+}
+
+// Clean up after graph processing is done
+static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
+ for (auto& buffer : ctx->gc.temp_buffers) {
+ ggml_vk_pool_free(ctx, buffer);
+ }
+ ctx->gc.temp_buffers.clear();
+
+ for (auto& pipeline : ctx->device->pipelines) {
+ if (pipeline.expired()) {
+ continue;
+ }
+
+ vk_pipeline pl = pipeline.lock();
+ ggml_pipeline_cleanup(pl);
+ }
+
+ ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
+ ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
+
+ for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
+ ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
+ }
+ ctx->gc.semaphores.clear();
+
+ for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
+ ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
+ }
+ ctx->gc.tl_semaphores.clear();
+ ctx->semaphore_idx = 0;
+
+ ctx->event_idx = 0;
+
+ for (auto& event : ctx->gc.events) {
+ ctx->device->device.resetEvent(event);
+ }
+
+ ctx->staging_offset = 0;
+
+ ctx->compute_ctx = nullptr;
+ ctx->transfer_ctx = nullptr;
+ ctx->gc.contexts.clear();
+}
+
+// Clean up on backend free
+static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
+ VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
+ ggml_vk_graph_cleanup(ctx);
+
+ ggml_vk_destroy_buffer(ctx->prealloc_x);
+ ggml_vk_destroy_buffer(ctx->prealloc_y);
+ ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+ ggml_vk_destroy_buffer(ctx->staging);
+
+ for (auto& buffer : ctx->buffer_pool) {
+ ggml_vk_destroy_buffer(buffer);
+ }
+
+ ctx->prealloc_size_x = 0;
+ ctx->prealloc_size_y = 0;
+ ctx->prealloc_size_split_k = 0;
+ ctx->staging_size = 0;
+
+ for (auto& event : ctx->gc.events) {
+ ctx->device->device.destroyEvent(event);
+ }
+ ctx->gc.events.clear();
+
+ ctx->device->device.destroyFence(ctx->fence);
+}
+
+GGML_CALL static int ggml_vk_get_device_count() {
+ ggml_vk_instance_init();
+
+ return vk_instance.device_indices.size();
+}
+
+GGML_CALL static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
+ ggml_vk_instance_init();
+
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ vk::PhysicalDeviceProperties props;
+ devices[device].getProperties(&props);
+
+ snprintf(description, description_size, "%s", props.deviceName.data());
+}
+
+// backend interface
+
+#define UNUSED GGML_UNUSED
+
+// device backend
+
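+// Fake non-null base pointer for device buffers: tensor->data stores an offset
+// relative to this value instead of a host address (see ggml_backend_vk_buffer_init_tensor).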
+static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
+
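+// Per-buffer context: owns the device allocation and a small pool of tensor extras
+// handed out round-robin by ggml_vk_alloc_temp_tensor_extra().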
+struct ggml_backend_vk_buffer_context {
+ vk_device_ref device;
+ vk_buffer dev_buffer;
+ ggml_tensor_extra_gpu * temp_tensor_extras = nullptr;
+ size_t temp_tensor_extra_index = 0;
+ std::string name;
+
+ ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
+ device(device),
+ dev_buffer(dev_buffer),
+ name(name) {
+ }
+
+ ~ggml_backend_vk_buffer_context() {
+ ggml_vk_destroy_buffer(dev_buffer);
+ if (temp_tensor_extras != nullptr) {
+ delete[] temp_tensor_extras;
+ }
+ }
+
+ ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() {
+ if (temp_tensor_extras == nullptr) {
+ temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_VK_MAX_NODES];
+ }
+
+ size_t alloc_index = temp_tensor_extra_index;
+ temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_VK_MAX_NODES;
+ ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
+ extra->reset();
+
+ return extra;
+ }
+};
+
+GGML_CALL static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) {
+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+ return ctx->name.c_str();
+}
+
+GGML_CALL static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
+ return buffer->iface.get_name == ggml_backend_vk_buffer_get_name;
+}
+
+GGML_CALL static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+ ggml_vk_destroy_buffer(ctx->dev_buffer);
+ delete ctx;
+}
+
+GGML_CALL static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
+ return vk_ptr_base;
+
+ UNUSED(buffer);
+}
+
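+// Views share the extra of their view_src; other tensors get a fresh extra that
+// records the device buffer and the offset encoded in tensor->data.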
+GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+
+ if (tensor->view_src != nullptr) {
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+ GGML_ASSERT(tensor->view_src->extra != nullptr);
+ tensor->extra = tensor->view_src->extra;
+ } else {
+ ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra();
+ extra->buffer_gpu = ctx->dev_buffer;
+ extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
+ tensor->extra = extra;
+ }
+}
+
+GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ vk_buffer buf = extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_write(buf, extra->offset + tensor->view_offs + offset, data, size);
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ vk_buffer buf = extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_read(buf, extra->offset + tensor->view_offs + offset, data, size);
+
+ GGML_UNUSED(buffer);
+}
+
+GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_vk(src->buffer)) {
+ ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ vk_buffer src_buf = src_extra->buffer_gpu.lock();
+ vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
+
+ return true;
+ }
+ return false;
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+
+ ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
+ /* .get_name = */ ggml_backend_vk_buffer_get_name,
+ /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_vk_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
+ /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_vk_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// vk buffer type
+GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
+
+ return ctx->name.c_str();
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
+
+ vk_buffer dev_buffer = nullptr;
+ try {
+ dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
+ } catch (const vk::SystemError& e) {
+ return nullptr;
+ }
+
+ ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
+}
+
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
+ return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+}
+
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+ ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
+ return ctx->device->max_memory_allocation_size;
+}
+
+GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ return ggml_nbytes(tensor);
+
+ UNUSED(buft);
+}
+
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
+ ggml_vk_instance_init();
+
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
+
+ vk_device dev = ggml_vk_get_device(dev_num);
+
+ return &dev->buffer_type;
+}
+
+// host buffer type
+
+GGML_CALL static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_VK_NAME "_Host";
+
+ UNUSED(buft);
+}
+
+GGML_CALL static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
+ return GGML_VK_NAME "_Host";
+
+ UNUSED(buffer);
+}
+
+GGML_CALL static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
+ ggml_vk_host_free(vk_instance.devices[0], buffer->context);
+}
+
+GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
+
+ size += 32; // Behave like the CPU buffer type
+ void * ptr = nullptr;
+ try {
+ ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
+ } catch (vk::SystemError& e) {
+ std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
+ std::cerr << "ggml_vulkan: " << e.what() << std::endl;
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.get_name = ggml_backend_vk_host_buffer_name;
+ buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
+
+ return buffer;
+
+ UNUSED(buft);
+}
+
+GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
+
+ UNUSED(buft);
+}
+
+// Should be changed to return device-specific host buffer type
+// but that probably requires changes in llama.cpp
+GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
+ static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+ },
+ /* .context = */ nullptr,
+ };
+
+ // Make sure device 0 is initialized
+ ggml_vk_instance_init();
+ ggml_vk_get_device(0);
+
+ return &ggml_backend_vk_buffer_type_host;
+}
+
+
+// backend
+
+GGML_CALL static const char * ggml_backend_vk_name(ggml_backend_t backend) {
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
+ return ctx->name.c_str();
+}
+
+GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
+
+ ggml_vk_cleanup(ctx);
+
+ delete ctx;
+ delete backend;
+}
+
+GGML_CALL static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
+ return &ctx->device->buffer_type;
+}
+
+GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ if (ctx->transfer_ctx == nullptr) {
+ // Initialize new transfer context
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ }
+
+ vk_buffer buf = extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_write_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+}
+
+GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ if (ctx->transfer_ctx == nullptr) {
+ // Initialize new transfer context
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ }
+
+ vk_buffer buf = extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_read_async(ctx->transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size, ctx->staging, ctx->staging_offset);
+}
+
+GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
+ VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
+ ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ if (ctx->transfer_ctx == nullptr) {
+ // Initialize new transfer context
+ ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
+ ggml_vk_ctx_begin(ctx->device, ctx->transfer_ctx);
+ }
+
+ vk_buffer src_buf = src_extra->buffer_gpu.lock();
+ vk_buffer dst_buf = dst_extra->buffer_gpu.lock();
+
+ ggml_vk_buffer_copy_async(ctx->transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src));
+ return true;
+ }
+
+ return false;
+}
+
+GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
+ VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+ if(ctx->transfer_ctx == nullptr) {
+ return;
+ }
+
+ ggml_vk_ctx_end(ctx->transfer_ctx);
+
+ for (auto& cpy : ctx->transfer_ctx->in_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+
+ ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
+ ctx->device->device.resetFences({ ctx->fence });
+
+ for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
+ memcpy(cpy.dst, cpy.src, cpy.n);
+ }
+
+ ctx->transfer_ctx = nullptr;
+}
+
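+// Nodes that produce no GPU work: empty tensors and pure view/layout operations.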
+static bool ggml_vk_is_empty(ggml_tensor * node) {
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+}
+
+GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
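+ // Three passes over the graph: gather the preallocation sizes and allocate the
+ // shared buffers, record the nodes into compute contexts, then submit and execute.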
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_vk_preallocate_buffers_graph(ctx, cgraph->nodes[i]);
+ }
+ ggml_vk_preallocate_buffers(ctx);
+
+ int last_node = cgraph->n_nodes - 1;
+
+ // If the graph ends in nodes that record no GPU work, the compute context would never be closed, so walk back to the last node that actually produces GPU work
+ while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
+ last_node -= 1;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_vk_build_graph(ctx,cgraph->nodes[i], i == last_node);
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (ggml_vk_is_empty(node)) {
+ continue;
+ }
+
+ bool ok = ggml_vk_compute_forward(ctx, node);
+ if (!ok) {
+ fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+#ifdef GGML_VULKAN_CHECK_RESULTS
+ else {
+ ggml_vk_check_results_1(ctx, node);
+ }
+#endif
+ GGML_ASSERT(ok);
+ }
+
+ ggml_vk_graph_cleanup(ctx);
+
+ return GGML_STATUS_SUCCESS;
+
+ UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+ // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
+
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ4_NL:
+ break;
+ default:
+ return false;
+ }
+ struct ggml_tensor * a;
+ struct ggml_tensor * b;
+ if (op->op == GGML_OP_MUL_MAT) {
+ a = op->src[0];
+ b = op->src[1];
+ } else {
+ a = op->src[2];
+ b = op->src[1];
+ }
+ if (a->ne[3] != b->ne[3]) {
+ return false;
+ }
+ return true;
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_IQ4_NL:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ return false;
+ } break;
+ // case GGML_OP_REPEAT:
+ // {
+ // ggml_type src0_type = op->src[0]->type;
+ // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ // } break;
+ case GGML_OP_ROPE:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_NORM:
+ case GGML_OP_ADD:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_CLAMP:
+ case GGML_OP_CONT:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_SUM_ROWS:
+ return true;
+ default:
+ return false;
+ }
+
+ UNUSED(backend);
+}
+
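+// Offload heuristic used by the scheduler: only ops working on reasonably large
+// batches are offloaded, since small batches are unlikely to amortize the
+// host<->device transfers.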
+GGML_CALL static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+ const int min_batch_size = 32;
+
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+ UNUSED(backend);
+}
+
+GGML_CALL static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+ if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
+ return false;
+ }
+
+ ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
+
+ return buft_ctx->device == ctx->device;
+}
+
+// TODO: enable async and synchronize
+static ggml_backend_i ggml_backend_vk_interface = {
+ /* .get_name = */ ggml_backend_vk_name,
+ /* .free = */ ggml_backend_vk_free,
+ /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
+ /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
+ /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
+ /* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_vk_graph_compute,
+ /* .supports_op = */ ggml_backend_vk_supports_op,
+ /* .supports_buft = */ ggml_backend_vk_supports_buft,
+ /* .offload_op = */ ggml_backend_vk_offload_op,
+ /* .event_new = */ NULL,
+ /* .event_free = */ NULL,
+ /* .event_record = */ NULL,
+ /* .event_wait = */ NULL,
+ /* .event_synchronize = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_vk_guid() {
+ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
+ return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
+ VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+
+ ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
+ ggml_vk_init(ctx, dev_num);
+
+ ggml_backend_t vk_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_vk_guid(),
+ /* .interface = */ ggml_backend_vk_interface,
+ /* .context = */ ctx,
+ };
+
+ return vk_backend;
+}
+
+GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+}
+
+GGML_CALL int ggml_backend_vk_get_device_count() {
+ return ggml_vk_get_device_count();
+}
+
+GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
+ ggml_vk_get_device_description(device, description, description_size);
+}
+
+GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+
+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *total = heap.size;
+ *free = heap.size;
+ break;
+ }
+ }
+}
+
+// backend registry
+GGML_CALL static ggml_backend_t ggml_backend_reg_vk_init(const char * params, void * user_data) {
+ ggml_backend_t vk_backend = ggml_backend_vk_init((int) (intptr_t) user_data);
+ return vk_backend;
+
+ UNUSED(params);
+}
+
+extern "C" GGML_CALL int ggml_backend_vk_reg_devices();
+
+GGML_CALL int ggml_backend_vk_reg_devices() {
+ ggml_vk_instance_init();
+
+ for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
+ char name[128];
+ snprintf(name, sizeof(name), "%s%zu", GGML_VK_NAME, i);
+ ggml_backend_register(name, ggml_backend_reg_vk_init, ggml_backend_vk_buffer_type(i), (void *) (intptr_t) i); // NOLINT
+ }
+ return vk_instance.device_indices.size();
+}
+
+// Extension availability
+static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef GGML_VULKAN_VALIDATE
+ bool portability_enumeration_ext = false;
+ // Check for portability enumeration extension for MoltenVK support
+ for (const auto& properties : instance_extensions) {
+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+ return true;
+ }
+ }
+ if (!portability_enumeration_ext) {
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+ }
+#endif
+ return false;
+
+ UNUSED(instance_extensions);
+}
+static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
+#ifdef __APPLE__
+ bool portability_enumeration_ext = false;
+ // Check for portability enumeration extension for MoltenVK support
+ for (const auto& properties : instance_extensions) {
+ if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
+ return true;
+ }
+ }
+ if (!portability_enumeration_ext) {
+ std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
+ }
+#endif
+ return false;
+
+ UNUSED(instance_extensions);
+}
+
+// checks
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
+ if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
+ return;
+ }
+ for (int j = 0; j < level; j++) {
+ std::cerr << " ";
+ }
+ std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
+
+ done.push_back(tensor);
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (tensor->src[i] != nullptr) {
+ ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
+ }
+ }
+}
+
+static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
+ if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
+ return;
+ }
+ i0 = std::max(i0, 5);
+ i1 = std::max(i1, 5);
+ i2 = std::max(i2, 0);
+ i3 = std::max(i3, 0);
+ fprintf(stderr, " ");
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ fprintf(stderr, "%7d ", idx1);
+ }
+ fprintf(stderr, "\n");
+ for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
+ fprintf(stderr, "%7d: ", idx0);
+ for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
+ if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
+ float val;
+ if (tensor->type == GGML_TYPE_F32) {
+ val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
+ } else if (tensor->type == GGML_TYPE_F16) {
+ val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
+ } else if (tensor->type == GGML_TYPE_I32) {
+ val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
+ } else {
+ GGML_ASSERT(false);
+ }
+ fprintf(stderr, "% 7.2f ", val);
+ } else {
+ fprintf(stderr, " ");
+ }
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tensor * tensor, const char * name) {
+ void * tensor_data = tensor->data;
+
+ if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ const size_t tensor_size = ggml_nbytes(tensor);
+ tensor_data = malloc(tensor_size);
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
+ ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
+ }
+
+ std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
+ std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
+ if (tensor->src[0] != nullptr) {
+ std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
+ }
+ if (tensor->src[1] != nullptr) {
+ std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
+ }
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
+ std::cerr << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(tensor, done);
+
+ if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ free(tensor_data);
+ }
+}
+
+void * comp_result;
+size_t comp_size;
+size_t comp_nb[GGML_MAX_DIMS];
+size_t check_counter = 0;
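+// Recompute the node on the CPU by cloning it and its sources into a scratch ggml
+// context; the CPU reference (kept in comp_result/comp_size/comp_nb) is later
+// compared against the Vulkan output by ggml_vk_check_results_1().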
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+ if (tensor->op == GGML_OP_TRANSPOSE) {
+ return;
+ }
+
+ check_counter++;
+ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
+ return;
+ }
+
+ VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
+
+ ggml_tensor * src0 = tensor->src[0];
+ ggml_tensor * src1 = tensor->src[1];
+ ggml_tensor * src2 = tensor->src[2];
+
+ struct ggml_init_params iparams = {
+ /*.mem_size =*/ 1024*1024*1024,
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ false,
+ };
+
+ struct ggml_context * ggml_ctx = ggml_init(iparams);
+
+ struct ggml_tensor * src0_clone = nullptr;
+ struct ggml_tensor * src1_clone = nullptr;
+ struct ggml_tensor * src2_clone = nullptr;
+ struct ggml_tensor * tensor_clone = nullptr;
+
+ size_t src0_size;
+ size_t src1_size;
+ size_t src2_size;
+
+ void * src0_buffer = nullptr;
+ void * src1_buffer = nullptr;
+ void * src2_buffer = nullptr;
+
+ if (src0 != nullptr) {
+ src0_clone = ggml_dup_tensor(ggml_ctx, src0);
+
+ src0_size = ggml_nbytes(src0);
+
+ src0_buffer = malloc(src0_size);
+ src0_clone->data = src0_buffer;
+ if (ggml_backend_buffer_is_host(src0->buffer)) {
+ memcpy(src0_clone->data, src0->data, src0_size);
+ memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ } else if (ggml_backend_buffer_is_vk(src0->buffer)) {
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
+ uint64_t offset = extra->offset + src0->view_offs;
+ if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
+ for (int i3 = 0; i3 < src0->ne[3]; i3++) {
+ for (int i2 = 0; i2 < src0->ne[2]; i2++) {
+ const int idx = i3*src0->ne[2] + i2;
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
+ }
+ }
+
+ src0_clone->nb[0] = src0->nb[0];
+ src0_clone->nb[1] = src0->nb[1];
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
+ src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
+ }
+ } else {
+ if (offset + src0_size >= buffer_gpu->size) {
+ src0_size = buffer_gpu->size - offset;
+ }
+ ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
+ memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ }
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ ggml_vk_print_tensor(ctx, src0, "src0");
+ }
+ }
+ if (src1 != nullptr) {
+ src1_clone = ggml_dup_tensor(ggml_ctx, src1);
+
+ src1_size = ggml_nbytes(src1);
+
+ src1_buffer = malloc(src1_size);
+ src1_clone->data = src1_buffer;
+ if (ggml_backend_buffer_is_host(src1->buffer)) {
+ memcpy(src1_clone->data, src1->data, src1_size);
+ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ } else if (ggml_backend_buffer_is_vk(src1->buffer)) {
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
+ uint64_t offset = extra->offset + src1->view_offs;
+ if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
+ for (int i3 = 0; i3 < src1->ne[3]; i3++) {
+ for (int i2 = 0; i2 < src1->ne[2]; i2++) {
+ const int idx = i3*src1->ne[2] + i2;
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
+ }
+ }
+
+ src1_clone->nb[0] = src1->nb[0];
+ src1_clone->nb[1] = src1->nb[1];
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
+ src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
+ }
+ } else {
+ if (offset + src1_size >= buffer_gpu->size) {
+ src1_size = buffer_gpu->size - offset;
+ }
+ ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
+ memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ }
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ ggml_vk_print_tensor(ctx, src1, "src1");
+ std::cerr << "TENSOR CHECK: " << ggml_op_name(src1_clone->op) << " (check " << check_counter << ")" << std::endl;
+ std::cerr << "src1_clone=" << tensor << " src1_clone->type: " << ggml_type_name(src1_clone->type) << " ne0=" << src1_clone->ne[0] << " nb0=" << src1_clone->nb[0] << " ne1=" << src1_clone->ne[1] << " nb1=" << src1_clone->nb[1] << " ne2=" << src1_clone->ne[2] << " nb2=" << src1_clone->nb[2] << " ne3=" << src1_clone->ne[3] << " nb3=" << src1_clone->nb[3] << std::endl;
+ if (src1->src[0] != nullptr) {
+ std::cerr << "src1->src[0]=" << src1->src[0] << " op=" << ggml_op_name(src1->src[0]->op) << " type=" << ggml_type_name(src1->src[0]->type) << " ne0=" << src1->src[0]->ne[0] << " nb0=" << src1->src[0]->nb[0] << " ne1=" << src1->src[0]->ne[1] << " nb1=" << src1->src[0]->nb[1] << " ne2=" << src1->src[0]->ne[2] << " nb2=" << src1->src[0]->nb[2] << " ne3=" << src1->src[0]->ne[3] << " nb3=" << src1->src[0]->nb[3] << std::endl;
+ }
+ if (src1->src[1] != nullptr) {
+ std::cerr << "src1->src[1]=" << src1->src[1] << " op=" << ggml_op_name(src1->src[1]->op) << " type=" << ggml_type_name(src1->src[1]->type) << " ne0=" << src1->src[1]->ne[0] << " nb0=" << src1->src[1]->nb[0] << " ne1=" << src1->src[1]->ne[1] << " nb1=" << src1->src[1]->nb[1] << " ne2=" << src1->src[1]->ne[2] << " nb2=" << src1->src[1]->nb[2] << " ne3=" << src1->src[1]->ne[3] << " nb3=" << src1->src[1]->nb[3] << std::endl;
+ }
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 0, 0);
+ std::cerr << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(src1_clone, src1_clone->data, 5, 5, 1, 0);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(src1_clone, done);
+ }
+ }
+ if (src2 != nullptr) {
+ src2_clone = ggml_dup_tensor(ggml_ctx, src2);
+
+ src2_size = ggml_nbytes(src2);
+
+ src2_buffer = malloc(src2_size);
+ src2_clone->data = src2_buffer;
+ if (ggml_backend_buffer_is_host(src2->buffer)) {
+ memcpy(src2_clone->data, src2->data, src2_size);
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ } else if (ggml_backend_buffer_is_vk(src2->buffer)) {
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
+ uint64_t offset = extra->offset + src2->view_offs;
+ if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
+ for (int i3 = 0; i3 < src2->ne[3]; i3++) {
+ for (int i2 = 0; i2 < src2->ne[2]; i2++) {
+ const int idx = i3*src2->ne[2] + i2;
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
+ }
+ }
+
+ src2_clone->nb[0] = src2->nb[0];
+ src2_clone->nb[1] = src2->nb[1];
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
+ src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
+ }
+ } else {
+ if (offset + src2_size >= buffer_gpu->size) {
+ src2_size = buffer_gpu->size - offset;
+ }
+ ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
+ memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
+ }
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ ggml_vk_print_tensor(ctx, src2, "src2");
+ std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
+ std::cerr << "src2_clone=" << tensor << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
+ if (src2->src[0] != nullptr) {
+ std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
+ }
+ if (src2->src[1] != nullptr) {
+ std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
+ }
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
+ std::cerr << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(src2_clone, done);
+ }
+ }
+
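+    // re-create the op being checked on the CPU ggml context from the cloned
+    // sources, so that a reference result can be computed with the CPU backend
+    // and compared against the Vulkan output in ggml_vk_check_results_1()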
+ if (tensor->op == GGML_OP_MUL_MAT) {
+ tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
+ tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+ } else if (tensor->op == GGML_OP_MUL) {
+ tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_DIV) {
+ tensor_clone = ggml_div(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_SCALE) {
+ tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
+ } else if (tensor->op == GGML_OP_SQR) {
+ tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
+ } else if (tensor->op == GGML_OP_CLAMP) {
+ tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+ } else if (tensor->op == GGML_OP_ADD) {
+ tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_NORM) {
+ tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_RMS_NORM) {
+ tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_SOFT_MAX) {
+ if (src1 != nullptr) {
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+ } else {
+ tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
+ }
+ } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(int *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_ROPE) {
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
+ const int mode = ((int32_t *) tensor->op_params)[2];
+ //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
+ float freq_base = ((float *) tensor->op_params)[5];
+ float freq_scale = ((float *) tensor->op_params)[6];
+ float ext_factor = ((float *) tensor->op_params)[7];
+ float attn_factor = ((float *) tensor->op_params)[8];
+ float beta_fast = ((float *) tensor->op_params)[9];
+ float beta_slow = ((float *) tensor->op_params)[10];
+ tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ } else if (tensor->op == GGML_OP_UNARY) {
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_SILU:
+ tensor_clone = ggml_silu(ggml_ctx, src0_clone);
+ break;
+ case GGML_UNARY_OP_GELU:
+ tensor_clone = ggml_gelu(ggml_ctx, src0_clone);
+ break;
+ case GGML_UNARY_OP_RELU:
+ tensor_clone = ggml_relu(ggml_ctx, src0_clone);
+ break;
+ default:
+ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
+ GGML_ASSERT(false);
+ }
+ } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
+ if (src1 == nullptr) {
+ tensor_clone = ggml_dup(ggml_ctx, src0_clone);
+ tensor_clone->type = tensor->type;
+ } else {
+ tensor_clone = ggml_cpy(ggml_ctx, src0_clone, src1_clone);
+ }
+ } else if (tensor->op == GGML_OP_CONT) {
+ tensor_clone = ggml_cont_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+ } else if (tensor->op == GGML_OP_RESHAPE) {
+ tensor_clone = ggml_reshape_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+ } else if (tensor->op == GGML_OP_VIEW) {
+ tensor_clone = ggml_view_4d(ggml_ctx, src0_clone, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
+ } else if (tensor->op == GGML_OP_PERMUTE) {
+ int32_t * params = (int32_t *)tensor->op_params;
+ tensor_clone = ggml_permute(ggml_ctx, src0_clone, params[0], params[1], params[2], params[3]);
+ } else if (tensor->op == GGML_OP_TRANSPOSE) {
+ tensor_clone = ggml_transpose(ggml_ctx, src0_clone);
+ } else if (tensor->op == GGML_OP_GET_ROWS) {
+ tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
+ } else if (tensor->op == GGML_OP_ARGSORT) {
+ tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_SUM_ROWS) {
+ tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
+ } else {
+ std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
+ GGML_ASSERT(false);
+ }
+
+ ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
+ ggml_build_forward_expand(cgraph, tensor_clone);
+
+ ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ ggml_vk_print_tensor(ctx, tensor_clone, "tensor_clone");
+ }
+
+ comp_size = ggml_nbytes(tensor_clone);
+
+ comp_result = malloc(comp_size);
+ memcpy(comp_result, tensor_clone->data, comp_size);
+ memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
+
+ if (src0 != nullptr) {
+ free(src0_buffer);
+ }
+ if (src1 != nullptr) {
+ free(src1_buffer);
+ }
+
+ ggml_free(ggml_ctx);
+}
+
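+// Compare the Vulkan result of `tensor` against the CPU reference computed by
+// ggml_vk_check_results_0() (stored in comp_result/comp_nb): report the first
+// mismatching element and fail if the average absolute error is too large.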
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor * tensor) {
+ if (tensor->op == GGML_OP_TRANSPOSE) {
+ return;
+ }
+ if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
+ return;
+ }
+
+ VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
+
+ ggml_tensor * src0 = tensor->src[0];
+ ggml_tensor * src1 = tensor->src[1];
+ ggml_tensor * src2 = tensor->src[2];
+
+ void * tensor_data = tensor->data;
+
+ if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ size_t tensor_size = ggml_nbytes(tensor);
+ tensor_data = malloc(tensor_size);
+
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+ vk_buffer buffer_gpu = extra->buffer_gpu.lock();
+ if (extra->offset + tensor->view_offs + tensor_size >= buffer_gpu->size) {
+ tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs);
+ }
+
+ ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
+ }
+
+ float first_error_result = -1.0f;
+ float first_error_correct = -1.0f;
+ std::array<int, 4> first_error = { -1, -1, -1, -1 };
+ double avg_err = 0.0;
+ size_t counter = 0;
+
+ for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
+ float correct = 0.0f;
+ float result = 0.0f;
+
+ if (buffer_size_fit) {
+ if (tensor->type == GGML_TYPE_F32) {
+ correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+ result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+ } else if (tensor->type == GGML_TYPE_F16) {
+ correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
+ result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
+ } else if (tensor->type == GGML_TYPE_I32) {
+ correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+ result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+ } else {
+ std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
+ }
+ } else {
+ std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
+ GGML_ASSERT(false);
+ }
+
+ if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
+ std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
+ if (src0 != nullptr) {
+ std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
+ }
+ if (src1 != nullptr) {
+ std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
+ }
+ if (src2 != nullptr) {
+ std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
+ }
+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
+ std::cerr << std::endl << "Correct:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(tensor, done);
+ GGML_ASSERT(false);
+ }
+ if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) {
+ first_error[0] = i0;
+ first_error[1] = i1;
+ first_error[2] = i2;
+ first_error[3] = i3;
+ first_error_result = result;
+ first_error_correct = correct;
+ }
+
+            // Special case: skip infinite values to avoid a NaN (or infinite) avg_err.
+            // NaN can also appear in the results; if both values are NaN the error counts as 0.
+ if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
+ avg_err += std::fabs(correct - result);
+ }
+ counter++;
+ }
+ }
+ }
+ }
+
+ avg_err /= counter;
+
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+ std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
+ if (src0 != nullptr) {
+ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
+ }
+ if (src1 != nullptr) {
+ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
+ }
+ if (src2 != nullptr) {
+ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
+ }
+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
+ std::cerr << std::endl << "Correct:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
+ std::cerr << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 1, 0);
+ std::cerr << std::endl << "Correct:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 1, 0);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(tensor, done);
+ }
+
+ if (avg_err > 0.05 || std::isnan(avg_err)) {
+ std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
+ std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
+ if (src0 != nullptr) {
+ std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
+ }
+ if (src1 != nullptr) {
+ std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
+ }
+ if (src2 != nullptr) {
+ std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
+ }
+ std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
+ std::cerr << std::endl << "Result:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
+ std::cerr << std::endl << "Correct:" << std::endl;
+ ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
+ std::cerr << std::endl;
+ std::vector<const ggml_tensor *> done;
+ ggml_vk_print_graph_origin(tensor, done);
+ GGML_ASSERT(false);
+ } else {
+ std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
+ }
+
+ free(comp_result);
+ comp_result = nullptr;
+ comp_size = 0;
+
+ if (ggml_backend_buffer_is_vk(tensor->buffer)) {
+ free(tensor_data);
+ }
+}
+#endif
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
new file mode 100644
index 00000000..95a1fc7d
--- /dev/null
+++ b/ggml/src/ggml.c
@@ -0,0 +1,22196 @@
+//
+// Copyright (C) 2023-2024 The ggml authors
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "ggml.h"
+#include "ggml-aarch64.h"
+#if GGML_USE_IQK_MULMAT
+#include "iqk/iqk_mul_mat.h"
+#endif
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <errno.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <float.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <signal.h>
+#if defined(__gnu_linux__)
+#include <syscall.h>
+#endif
+
+#ifdef GGML_USE_OPENMP
+#include <omp.h>
+#endif
+
+#ifdef GGML_USE_METAL
+#include <unistd.h>
+#endif
+
+#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
+#undef GGML_USE_LLAMAFILE
+#endif
+
+#ifdef GGML_USE_LLAMAFILE
+#include <llamafile/sgemm.h>
+#endif
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnings
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
+#endif
+
+#if defined(_WIN32)
+
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+ #define NOMINMAX
+#endif
+#include <windows.h>
+
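+// minimal shims over the Win32 API for the small subset of C11 atomics and
+// pthreads functionality that ggml needs on Windows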
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
+
+static void atomic_store(atomic_int * ptr, LONG val) {
+ InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int * ptr) {
+ return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
+ return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
+ return atomic_fetch_add(ptr, -(dec));
+}
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+ return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+ InterlockedExchange(ptr, 0);
+}
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
+ (void) unused;
+ HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+ if (handle == NULL)
+ {
+ return EAGAIN;
+ }
+
+ *out = handle;
+ return 0;
+}
+
+static int pthread_join(pthread_t thread, void * unused) {
+ (void) unused;
+ int ret = (int) WaitForSingleObject(thread, INFINITE);
+ CloseHandle(thread);
+ return ret;
+}
+
+static int sched_yield (void) {
+ Sleep (0);
+ return 0;
+}
+#else
+#include <pthread.h>
+#include <stdatomic.h>
+
+typedef void * thread_ret_t;
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#endif
+
+typedef pthread_t ggml_thread_t;
+
+#ifdef GGML_USE_CPU_HBM
+#include <hbwmalloc.h>
+#endif
+
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
+#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
+ (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
+
+#include <sys/wait.h>
+
+void ggml_print_backtrace(void) {
+ /*
+ #include <execinfo.h>
+ #include <dlfcn.h>
+
+ void * trace[100];
+
+ int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
+
+ backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
+ */
+
+    // backtrace_symbols does not show line numbers, so use gdb instead
+ char attach[32];
+ snprintf(attach, sizeof(attach), "attach %d", getpid());
+ int pid = fork();
+ if (pid == 0) {
+ execlp("gdb", "gdb", "--batch",
+ "-ex", "set style enabled on",
+ "-ex", attach,
+ "-ex", "bt -frame-info source-and-location",
+ "-ex", "detach",
+ "-ex", "quit",
+ (char *) NULL);
+ } else {
+ waitpid(pid, NULL, 0);
+ }
+}
+#else
+void ggml_print_backtrace(void) {
+ // platform not supported
+}
+#endif
+
+#define GGML_DEBUG 0
+#define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL 2
+#define GGML_VEC_MAD_UNROLL 32
+
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
+//
+// end of logging block
+//
+
+#ifdef GGML_USE_ACCELERATE
+// uncomment to use vDSP for soft max computation
+// note: not sure if it is actually faster
+//#define GGML_SOFT_MAX_ACCELERATE
+#endif
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
+#else
+inline static void * ggml_aligned_malloc(size_t size) {
+ if (size == 0) {
+ GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
+ return NULL;
+ }
+ void * aligned_memory = NULL;
+#ifdef GGML_USE_CPU_HBM
+ int result = hbw_posix_memalign(&aligned_memory, 16, size);
+#elif GGML_USE_METAL
+ int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
+#else
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+#endif
+ if (result != 0) {
+ // Handle allocation failure
+ const char *error_desc = "unknown allocation error";
+ switch (result) {
+ case EINVAL:
+ error_desc = "invalid alignment value";
+ break;
+ case ENOMEM:
+ error_desc = "insufficient memory";
+ break;
+ }
+ GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
+ GGML_ASSERT(false);
+ return NULL;
+ }
+ return aligned_memory;
+}
+#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
+#ifdef GGML_USE_CPU_HBM
+#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr)
+#else
+#define GGML_ALIGNED_FREE(ptr) free(ptr)
+#endif
+#endif
+
+inline static void * ggml_malloc(size_t size) {
+ if (size == 0) {
+ GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n");
+ return NULL;
+ }
+ void * result = malloc(size);
+ if (result == NULL) {
+ GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
+ GGML_ASSERT(false);
+ }
+ return result;
+}
+
+// calloc
+inline static void * ggml_calloc(size_t num, size_t size) {
+ if (num == 0 || size == 0) {
+ GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n");
+ return NULL;
+ }
+ void * result = calloc(num, size);
+ if (result == NULL) {
+ GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
+ GGML_ASSERT(false);
+ }
+ return result;
+}
+
+#define GGML_MALLOC(size) ggml_malloc(size)
+#define GGML_CALLOC(num, size) ggml_calloc(num, size)
+
+#define GGML_FREE(ptr) free(ptr)
+
+#define UNUSED GGML_UNUSED
+#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
+
+#if defined(GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
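+// (accumulating in double matters for long reductions: a single-precision running
+//  sum stops changing once it is much larger than the addend, e.g.
+//  16777216.0f + 1.0f == 16777216.0f because 2^24 + 1 is not representable as f32)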
+
+#undef MIN
+#undef MAX
+
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
+
+// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
+float ggml_table_f32_f16[1 << 16];
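+// the f16 tables above are indexed by the raw 16-bit pattern of the half-precision
+// value, so GELU (and, on some platforms, f16 -> f32 conversion) can be done with a
+// single table lookup per element; they are filled once during the first ggml_init()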
+
+GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
+ switch (status) {
+ case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
+ case GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
+ case GGML_STATUS_SUCCESS: return "GGML status: success";
+ case GGML_STATUS_ABORTED: return "GGML status: warning (operation aborted)";
+ }
+
+ return "GGML status: unknown";
+}
+
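+// note: the #define placed inside each of the four conversion functions below
+// rebinds the public name to an undeclared symbol for the remainder of this file,
+// so internal code cannot accidentally call the non-inlined exported function and
+// must use the GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 (and BF16) macros instead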
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
+ return GGML_FP16_TO_FP32(x);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
+ return GGML_FP32_TO_FP16(x);
+}
+
+float ggml_bf16_to_fp32(ggml_bf16_t x) {
+#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
+    return GGML_BF16_TO_FP32(x);  // it just left-shifts the 16-bit pattern into the high half of an f32
+}
+
+ggml_bf16_t ggml_fp32_to_bf16(float x) {
+#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
+ return GGML_FP32_TO_BF16(x);
+}
+
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
+ for (int64_t i = 0; i < n; i++) {
+ y[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+}
+
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
+ int64_t i = 0;
+#if defined(__F16C__)
+ for (; i + 7 < n; i += 8) {
+ __m256 x_vec = _mm256_loadu_ps(x + i);
+ __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+ _mm_storeu_si128((__m128i *)(y + i), y_vec);
+ }
+ for(; i + 3 < n; i += 4) {
+ __m128 x_vec = _mm_loadu_ps(x + i);
+ __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+ _mm_storel_epi64((__m128i *)(y + i), y_vec);
+ }
+#endif
+ for (; i < n; i++) {
+ y[i] = GGML_FP32_TO_FP16(x[i]);
+ }
+}
+
+void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
+ int64_t i = 0;
+#if defined(__AVX512F__)
+ for (; i + 16 <= n; i += 16) {
+ _mm512_storeu_ps(y + i,
+ _mm512_castsi512_ps(
+ _mm512_slli_epi32(
+ _mm512_cvtepu16_epi32(
+ _mm256_loadu_si256(
+ (const __m256i *)(x + i))),
+ 16)));
+ }
+#elif defined(__AVX2__)
+ for (; i + 8 <= n; i += 8) {
+ _mm256_storeu_ps(y + i,
+ _mm256_castsi256_ps(
+ _mm256_slli_epi32(
+ _mm256_cvtepu16_epi32(
+ _mm_loadu_si128(
+ (const __m128i *)(x + i))),
+ 16)));
+ }
+#endif
+ for (; i < n; i++) {
+ y[i] = GGML_BF16_TO_FP32(x[i]);
+ }
+}
+
+void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
+ int i = 0;
+#if defined(__AVX512BF16__)
+ for (; i + 32 <= n; i += 32) {
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
+ }
+#endif
+ for (; i < n; i++) {
+ y[i] = GGML_FP32_TO_BF16(x[i]);
+ }
+}
+
+bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
+ return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
+}
+
+//
+// timing
+//
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+static int64_t timer_freq, timer_start;
+void ggml_time_init(void) {
+ LARGE_INTEGER t;
+ QueryPerformanceFrequency(&t);
+ timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can overflow if timer_freq
+    // and the uptime are high enough.
+ // We subtract the program start time to reduce the likelihood of that happening.
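+    // (for example, with the common 10 MHz QueryPerformanceCounter frequency the
+    //  product ticks * 1000000 overflows int64_t once the raw counter exceeds about
+    //  9.2e12 ticks, i.e. after roughly 10-11 days of uptime)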
+ QueryPerformanceCounter(&t);
+ timer_start = t.QuadPart;
+}
+int64_t ggml_time_ms(void) {
+ LARGE_INTEGER t;
+ QueryPerformanceCounter(&t);
+ return ((t.QuadPart-timer_start) * 1000) / timer_freq;
+}
+int64_t ggml_time_us(void) {
+ LARGE_INTEGER t;
+ QueryPerformanceCounter(&t);
+ return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
+}
+#else
+void ggml_time_init(void) {}
+int64_t ggml_time_ms(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+#endif
+
+int64_t ggml_cycles(void) {
+ return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+ return CLOCKS_PER_SEC/1000;
+}
+
+//
+// cross-platform UTF-8 file paths
+//
+
+#ifdef _WIN32
+static wchar_t * ggml_mbstowcs(const char * mbs) {
+ int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
+ if (!wlen) {
+ errno = EINVAL;
+ return NULL;
+ }
+
+ wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
+ wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
+ if (!wlen) {
+ GGML_FREE(wbuf);
+ errno = EINVAL;
+ return NULL;
+ }
+
+ return wbuf;
+}
+#endif
+
+FILE * ggml_fopen(const char * fname, const char * mode) {
+#ifdef _WIN32
+ FILE * file = NULL;
+
+ // convert fname (UTF-8)
+ wchar_t * wfname = ggml_mbstowcs(fname);
+ if (wfname) {
+ // convert mode (ANSI)
+ wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
+ wchar_t * wmode_p = wmode;
+ do {
+ *wmode_p++ = (wchar_t)*mode;
+ } while (*mode++);
+
+ // open file
+ file = _wfopen(wfname, wmode);
+
+ GGML_FREE(wfname);
+ GGML_FREE(wmode);
+ }
+
+ return file;
+#else
+ return fopen(fname, mode);
+#endif
+}
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
+#else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+#endif
+
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+
+static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_I8] = {
+ .type_name = "i8",
+ .blck_size = 1,
+ .type_size = sizeof(int8_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_I16] = {
+ .type_name = "i16",
+ .blck_size = 1,
+ .type_size = sizeof(int16_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_I32] = {
+ .type_name = "i32",
+ .blck_size = 1,
+ .type_size = sizeof(int32_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_I64] = {
+ .type_name = "i64",
+ .blck_size = 1,
+ .type_size = sizeof(int64_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_F64] = {
+ .type_name = "f64",
+ .blck_size = 1,
+ .type_size = sizeof(double),
+ .is_quantized = false,
+ .nrows = 1,
+ },
+ [GGML_TYPE_F32] = {
+ .type_name = "f32",
+ .blck_size = 1,
+ .type_size = sizeof(float),
+ .is_quantized = false,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+ .vec_dot_type = GGML_TYPE_F32,
+ .nrows = 1,
+ },
+ [GGML_TYPE_F16] = {
+ .type_name = "f16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_fp16_t),
+ .is_quantized = false,
+ .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+ .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
+ .vec_dot_type = GGML_TYPE_F16,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q4_0] = {
+ .type_name = "q4_0",
+ .blck_size = QK4_0,
+ .type_size = sizeof(block_q4_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_0,
+ .from_float = quantize_row_q4_0,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
+ .vec_dot = ggml_vec_dot_q4_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
+ },
+ [GGML_TYPE_Q4_1] = {
+ .type_name = "q4_1",
+ .blck_size = QK4_1,
+ .type_size = sizeof(block_q4_1),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_1,
+ .from_float = quantize_row_q4_1,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
+ .vec_dot = ggml_vec_dot_q4_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
+ },
+ [4] = { // GGML_TYPE_Q4_2
+ .type_name = "DEPRECATED",
+ .blck_size = 0,
+ .type_size = 0,
+ .is_quantized = false,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_COUNT,
+ .nrows = 1,
+ },
+ [5] = { // GGML_TYPE_Q4_3
+ .type_name = "DEPRECATED",
+ .blck_size = 0,
+ .type_size = 0,
+ .is_quantized = false,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_COUNT,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q5_0] = {
+ .type_name = "q5_0",
+ .blck_size = QK5_0,
+ .type_size = sizeof(block_q5_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_0,
+ .from_float = quantize_row_q5_0,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
+ .vec_dot = ggml_vec_dot_q5_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q5_1] = {
+ .type_name = "q5_1",
+ .blck_size = QK5_1,
+ .type_size = sizeof(block_q5_1),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_1,
+ .from_float = quantize_row_q5_1,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
+ .vec_dot = ggml_vec_dot_q5_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q8_0] = {
+ .type_name = "q8_0",
+ .blck_size = QK8_0,
+ .type_size = sizeof(block_q8_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_0,
+ .from_float = quantize_row_q8_0,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
+ .from_float_to_mat = quantize_mat_q8_0,
+ .vec_dot = ggml_vec_dot_q8_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+#else
+ .nrows = 1,
+#endif
+ },
+ [GGML_TYPE_Q8_1] = {
+ .type_name = "q8_1",
+ .blck_size = QK8_1,
+ .type_size = sizeof(block_q8_1),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_1,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q2_K] = {
+ .type_name = "q2_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q2_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q2_K,
+ .from_float = quantize_row_q2_K,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
+ .vec_dot = ggml_vec_dot_q2_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q3_K] = {
+ .type_name = "q3_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q3_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q3_K,
+ .from_float = quantize_row_q3_K,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
+ .vec_dot = ggml_vec_dot_q3_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q4_K] = {
+ .type_name = "q4_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q4_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_K,
+ .from_float = quantize_row_q4_K,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
+ .vec_dot = ggml_vec_dot_q4_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q5_K] = {
+ .type_name = "q5_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q5_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_K,
+ .from_float = quantize_row_q5_K,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
+ .vec_dot = ggml_vec_dot_q5_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q6_K] = {
+ .type_name = "q6_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q6_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_K,
+ .from_float = quantize_row_q6_K,
+ .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
+ .vec_dot = ggml_vec_dot_q6_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ2_XXS] = {
+ .type_name = "iq2_xxs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_xxs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ2_XS] = {
+ .type_name = "iq2_xs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_xs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ3_XXS] = {
+ .type_name = "iq3_xxs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq3_xxs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
+ .from_float = quantize_row_iq3_xxs,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
+ .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ3_S] = {
+ .type_name = "iq3_s",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq3_s),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
+ .from_float = quantize_row_iq3_s,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
+ .vec_dot = ggml_vec_dot_iq3_s_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ2_S] = {
+ .type_name = "iq2_s",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq2_s),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_s,
+ .from_float = quantize_row_iq2_s,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
+ .vec_dot = ggml_vec_dot_iq2_s_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ1_S] = {
+ .type_name = "iq1_s",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq1_s),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = ggml_vec_dot_iq1_s_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ1_M] = {
+ .type_name = "iq1_m",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq1_m),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq1_m,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = ggml_vec_dot_iq1_m_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ1_BN] = {
+ .type_name = "iq1_bn",
+ .blck_size = QK_IQ1BN,
+ .type_size = sizeof(block_iq1_bn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq1_bn,
+ .from_float = quantize_row_iq1_bn,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq1_bn_ref,
+ .vec_dot = ggml_vec_dot_iq1_bn_q8_K64,
+ .vec_dot_type = GGML_TYPE_Q8_K64,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ2_BN] = {
+ .type_name = "iq2_bn",
+ .blck_size = QK_IQ1BN,
+ .type_size = sizeof(block_iq2_bn),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq2_bn,
+ .from_float = quantize_row_iq2_bn,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_ref,
+ .vec_dot = ggml_vec_dot_iq2_bn_q8_K64,
+ .vec_dot_type = GGML_TYPE_Q8_K64,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ4_NL] = {
+ .type_name = "iq4_nl",
+ .blck_size = QK4_NL,
+ .type_size = sizeof(block_iq4_nl),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
+ .from_float = quantize_row_iq4_nl,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
+ .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ },
+ [GGML_TYPE_IQ4_XS] = {
+ .type_name = "iq4_xs",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_iq4_xs),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
+ .from_float = quantize_row_iq4_xs,
+ .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
+ .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q8_K] = {
+ .type_name = "q8_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q8_K),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_K,
+ },
+ [GGML_TYPE_Q8_K64] = {
+ .type_name = "q8_K64",
+ .blck_size = 64,
+ .type_size = sizeof(block_q8_K64),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_K64,
+ },
+ [GGML_TYPE_BF16] = {
+ .type_name = "bf16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_bf16_t),
+ .is_quantized = false,
+ .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
+ .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+ .vec_dot_type = GGML_TYPE_BF16,
+ .nrows = 1,
+ },
+ [GGML_TYPE_Q4_0_4_4] = {
+ .type_name = "q4_0_4x4",
+ .blck_size = QK4_0,
+ .blck_size_interleave = 4,
+ .type_size = sizeof(block_q4_0),
+ .is_quantized = true,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ .ncols = 4,
+ .gemv = ggml_gemv_q4_0_4x4_q8_0,
+ .gemm = ggml_gemm_q4_0_4x4_q8_0,
+ },
+ [GGML_TYPE_Q4_0_4_8] = {
+ .type_name = "q4_0_4x8",
+ .blck_size = QK4_0,
+ .blck_size_interleave = 8,
+ .type_size = sizeof(block_q4_0),
+ .is_quantized = true,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ .ncols = 4,
+ .gemv = ggml_gemv_q4_0_4x8_q8_0,
+ .gemm = ggml_gemm_q4_0_4x8_q8_0,
+ },
+ [GGML_TYPE_Q4_0_8_8] = {
+ .type_name = "q4_0_8x8",
+ .blck_size = QK4_0,
+ .blck_size_interleave = 8,
+ .type_size = sizeof(block_q4_0),
+ .is_quantized = true,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_ref = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ .ncols = 8,
+ .gemv = ggml_gemv_q4_0_8x8_q8_0,
+ .gemm = ggml_gemm_q4_0_8x8_q8_0,
+ }
+};
+
+// For internal test use
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+ GGML_ASSERT(type < GGML_TYPE_COUNT);
+ return type_traits[type];
+}
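+
+// Illustrative sketch (hypothetical test code, not part of this file): round-trip
+// one block of Q8_0 data through the traits table; `src` is assumed to point to at
+// least QK8_0 floats.
+//
+//   ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q8_0);
+//   block_q8_0 blk;
+//   float dst[QK8_0];
+//   tt.from_float(src, &blk, QK8_0);   // quantize QK8_0 floats into one block
+//   tt.to_float(&blk, dst, QK8_0);     // dequantize the block back to f32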
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for a new architecture requires defining the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+// number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+// number of elements to fit in a single register
+//
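+// For example, a plain dot product written only in terms of these generic macros
+// looks roughly like the sketch below (GGML_F32_ARR = GGML_F32_STEP/GGML_F32_EPR is
+// defined after the per-architecture blocks; this is an illustration, the real
+// kernels live in ggml_vec_dot_f32() and friends):
+//
+//   float sumf = 0.0f;
+//   const int np = (n & ~(GGML_F32_STEP - 1));
+//   GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+//   for (int i = 0; i < np; i += GGML_F32_STEP) {
+//       for (int j = 0; j < GGML_F32_ARR; j++) {
+//           GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+//           GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+//           sum[j] = GGML_F32_VEC_FMA(sum[j], ax, ay);
+//       }
+//   }
+//   GGML_F32_VEC_REDUCE(sumf, sum);   // horizontal sum of all accumulators
+//   for (int i = np; i < n; i++) {    // scalar tail
+//       sumf += x[i]*y[i];
+//   }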
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 float32x4_t
+#define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x) vdupq_n_f32(x)
+#define GGML_F32x4_LOAD vld1q_f32
+#define GGML_F32x4_STORE vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD vaddq_f32
+#define GGML_F32x4_MUL vmulq_f32
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
+ } \
+ res = GGML_F32x4_REDUCE_ONE(x[0]); \
+}
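+// (the REDUCE macros here and in the other architecture blocks all follow the same
+//  pattern: the GGML_F32_ARR partial accumulators are pairwise-added in halving
+//  steps until x[0] holds the total, which is then reduced horizontally to a scalar)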
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ #define GGML_F16_STEP 32
+ #define GGML_F16_EPR 8
+
+ #define GGML_F16x8 float16x8_t
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
+ #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x))
+ #define GGML_F16x8_STORE vst1q_f16
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+ #define GGML_F16x8_ADD vaddq_f16
+ #define GGML_F16x8_MUL vmulq_f16
+ #define GGML_F16x8_REDUCE(res, x) \
+ do { \
+ int offset = GGML_F16_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
+ } \
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
+ res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
+ } while (0)
+
+ #define GGML_F16_VEC GGML_F16x8
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i])
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
+#else
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
+ // and take advantage of the vcvt_ functions to convert to/from FP16
+
+ #define GGML_F16_STEP 16
+ #define GGML_F16_EPR 4
+
+ #define GGML_F32Cx4 float32x4_t
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+ #define GGML_F32Cx4_ADD vaddq_f32
+ #define GGML_F32Cx4_MUL vmulq_f32
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+ #define GGML_F16_VEC GGML_F32Cx4
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX512F__)
+
+#define GGML_SIMD
+
+// F32 AVX512
+
+#define GGML_F32_STEP 64
+#define GGML_F32_EPR 16
+
+#define GGML_F32x16 __m512
+#define GGML_F32x16_ZERO _mm512_setzero_ps()
+#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
+#define GGML_F32x16_LOAD _mm512_loadu_ps
+#define GGML_F32x16_STORE _mm512_storeu_ps
+// _mm512_fmadd_ps is defined in AVX512F so no guard is required
+#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32x16_ADD _mm512_add_ps
+#define GGML_F32x16_MUL _mm512_mul_ps
+#define GGML_F32x16_REDUCE(res, x) \
+do { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ res = _mm512_reduce_add_ps(x[0]); \
+} while (0)
+
+// TODO: is this optimal?
+
+#define GGML_F32_VEC GGML_F32x16
+#define GGML_F32_VEC_ZERO GGML_F32x16_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x16_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x16_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x16_STORE
+#define GGML_F32_VEC_FMA GGML_F32x16_FMA
+#define GGML_F32_VEC_ADD GGML_F32x16_ADD
+#define GGML_F32_VEC_MUL GGML_F32x16_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
+
+// F16 AVX512
+
+// F16 AVX
+
+#define GGML_F16_STEP 64
+#define GGML_F16_EPR 16
+
+// AVX512 has an FP16 extension (AVX512_FP16), but I don't have it on my machine, so I use FP32 instead
+
+#define GGML_F32Cx16 __m512
+#define GGML_F32Cx16_ZERO _mm512_setzero_ps()
+#define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x)
+
+// unlike the _mm256_cvt intrinsics, which require F16C, _mm512_cvt is defined in AVX512F,
+// so no F16C guard is required
+#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
+
+#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32Cx16_ADD _mm512_add_ps
+#define GGML_F32Cx16_MUL _mm512_mul_ps
+#define GGML_F32Cx16_REDUCE(res, x) \
+do { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm512_add_ps(x[i], x[offset+i]); \
+ } \
+ res = _mm512_reduce_add_ps(x[0]); \
+} while (0)
+
+#define GGML_F16_VEC GGML_F32Cx16
+#define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
+
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 8
+
+#define GGML_F32x8 __m256
+#define GGML_F32x8_ZERO _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD _mm256_loadu_ps
+#define GGML_F32x8_STORE _mm256_storeu_ps
+#if defined(__FMA__)
+ #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+ #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD _mm256_add_ps
+#define GGML_F32x8_MUL _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
+ } \
+ const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
+ _mm256_extractf128_ps(x[0], 1)); \
+ const __m128 t1 = _mm_hadd_ps(t0, t0); \
+ res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
+} while (0)
+// TODO: is this optimal?
+
+#define GGML_F32_VEC GGML_F32x8
+#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8 __m256
+#define GGML_F32Cx8_ZERO _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+ float tmp[8];
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+
+ return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+ float arr[8];
+
+ _mm256_storeu_ps(arr, y);
+
+ for (int i = 0; i < 8; i++)
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
+#define GGML_F32Cx8_FMA GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD _mm256_add_ps
+#define GGML_F32Cx8_MUL _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx8
+#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 vector float
+#define GGML_F32x4_ZERO 0.0f
+#define GGML_F32x4_SET1 vec_splats
+#define GGML_F32x4_LOAD(p) vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD vec_add
+#define GGML_F32x4_MUL vec_mul
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vec_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vec_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = vec_add(x[i], x[offset+i]); \
+ } \
+ res = vec_extract(x[0], 0) + \
+ vec_extract(x[0], 1) + \
+ vec_extract(x[0], 2) + \
+ vec_extract(x[0], 3); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+#define GGML_F16_STEP GGML_F32_STEP
+#define GGML_F16_EPR GGML_F32_EPR
+#define GGML_F16_VEC GGML_F32x4
+#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
+ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+ vec_extract_fp32_from_shortl(vec_xl(0, p))
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i) \
+ if (i & 0x1) \
+ vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
+ r[i - GGML_ENDIAN_BYTE(0)]), \
+ 0, p - GGML_F16_EPR)
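+// Note: the F16 load/store macros above work on even/odd register pairs. The even
+// register i takes the low 4 halves of an 8-element vec_xl, the odd register i+1
+// re-reads the same 16 bytes (p - GGML_F16_EPR) and takes the high 4 halves; only
+// odd i triggers the store, packing the pair back into 8 halves with
+// GGML_ENDIAN_BYTE selecting the lane order.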
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 v128_t
+#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD wasm_v128_load
+#define GGML_F32x4_STORE wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD wasm_f32x4_add
+#define GGML_F32x4_MUL wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
+ wasm_f32x4_extract_lane(x[0], 1) + \
+ wasm_f32x4_extract_lane(x[0], 2) + \
+ wasm_f32x4_extract_lane(x[0], 3); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR 4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(p[0]);
+ tmp[1] = GGML_FP16_TO_FP32(p[1]);
+ tmp[2] = GGML_FP16_TO_FP32(p[2]);
+ tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+ return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+ float tmp[4];
+
+ wasm_v128_store(tmp, x);
+
+ p[0] = GGML_FP32_TO_FP16(tmp[0]);
+ p[1] = GGML_FP32_TO_FP16(tmp[1]);
+ p[2] = GGML_FP32_TO_FP16(tmp[2]);
+ p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4 v128_t
+#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA GGML_F32x4_FMA
+#define GGML_F16x4_ADD wasm_f32x4_add
+#define GGML_F16x4_MUL wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F16_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
+ } \
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
+ wasm_f32x4_extract_lane(x[0], 1) + \
+ wasm_f32x4_extract_lane(x[0], 2) + \
+ wasm_f32x4_extract_lane(x[0], 3); \
+}
+
+#define GGML_F16_VEC GGML_F16x4
+#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 __m128
+#define GGML_F32x4_ZERO _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD _mm_loadu_ps
+#define GGML_F32x4_STORE _mm_storeu_ps
+#if defined(__FMA__)
+ // TODO: Does this work?
+ #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+ #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD _mm_add_ps
+#define GGML_F32x4_MUL _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
+ } \
+ const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
+ res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
+}
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 4
+
+static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+ return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
+ float arr[4];
+
+ _mm_storeu_ps(arr, y);
+
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4 __m128
+#define GGML_F32Cx4_ZERO _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD _mm_add_ps
+#define GGML_F32Cx4_MUL _mm_mul_ps
+#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx4
+#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 8
+
+#define GGML_F32x8 __m256
+#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD __lasx_xvfadd_s
+#define GGML_F32x8_MUL __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x) \
+do { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ float *tmp_p = (float *)&x[0]; \
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC GGML_F32x8
+#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 8
+
+// F16 arithmetic is not supported by LASX, so we use F32 instead
+
+#define GGML_F32Cx8 __m256
+#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+ float tmp[8];
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+
+ return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+ float arr[8];
+
+ __lasx_xvst(y, arr, 0);
+
+ for (int i = 0; i < 8; i++) {
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
+ }
+}
+#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx8
+#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR 4
+
+#define GGML_F32x4 __m128
+#define GGML_F32x4_ZERO __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE(x, y)   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD __lsx_vfadd_s
+#define GGML_F32x4_MUL __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x) \
+{ \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC GGML_F32x4
+#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR 4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+ return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+ float arr[4];
+
+ __lsx_vst(y, arr, 0);
+
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4 __m128
+#define GGML_F32Cx4_ZERO __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD __lsx_vfadd_s
+#define GGML_F32Cx4_MUL __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC GGML_F32Cx4
+#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+// number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
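+// e.g. SSE: 32/4 = 8 accumulators per step, WASM SIMD: 16/4 = 4, LASX: 32/8 = 4;
+// spreading each step over several registers helps hide FMA/add latency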
+
+//
+// ggml context
+//
+
+struct ggml_context {
+ size_t mem_size;
+ void* mem_buffer;
+ bool mem_buffer_owned;
+ bool no_alloc;
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
+
+ int n_objects;
+
+ struct ggml_object * objects_begin;
+ struct ggml_object * objects_end;
+
+ struct ggml_scratch scratch;
+ struct ggml_scratch scratch_save;
+};
+
+struct ggml_context_container {
+ bool used;
+
+ struct ggml_context context;
+};
+
+struct ggml_compute_state_shared {
+ const struct ggml_cgraph * cgraph;
+ const struct ggml_cplan * cplan;
+
+ int n_threads;
+
+ // synchronization primitives
+ atomic_int n_barrier;
+ atomic_int n_barrier_passed;
+
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
+ void * abort_callback_data;
+
+ atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads
+
+ enum ggml_status ec;
+};
+
+struct ggml_compute_state {
+ ggml_thread_t thrd;
+ int ith;
+ struct ggml_compute_state_shared * shared;
+};
+
+struct ggml_compute_params {
+ // ith = thread index, nth = number of threads
+ int ith, nth;
+
+ // work buffer for all threads
+ size_t wsize;
+ void * wdata;
+
+ struct ggml_compute_state_shared * shared;
+};
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
+
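+// dot product of two f32 vectors: the SIMD path consumes GGML_F32_STEP elements per
+// iteration using GGML_F32_ARR partial sums that are reduced at the end, and the
+// remaining n % GGML_F32_STEP elements are handled by the scalar tail loop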
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+#if defined(GGML_SIMD)
+ float sumf = 0.0f;
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+#else
+ // scalar
+ ggml_float sumf = 0.0;
+ for (int i = 0; i < n; ++i) {
+ sumf += (ggml_float)(x[i]*y[i]);
+ }
+#endif
+
+ *s = sumf;
+}
+
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ int i = 0;
+ ggml_float sumf = 0;
+
+#if defined(__AVX512BF16__)
+ __m512 c1 = _mm512_setzero_ps();
+ __m512 c2 = _mm512_setzero_ps();
+ for (; i + 64 <= n; i += 64) {
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
+ }
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#elif defined(__AVX512F__)
+#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+ __m512 c1 = _mm512_setzero_ps();
+ __m512 c2 = _mm512_setzero_ps();
+ for (; i + 32 <= n; i += 32) {
+ c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+ c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+ }
+ sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+ sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#undef LOAD
+#elif defined(__AVX2__)
+#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+ __m256 c1 = _mm256_setzero_ps();
+ __m256 c2 = _mm256_setzero_ps();
+ __m256 c3 = _mm256_setzero_ps();
+ __m256 c4 = _mm256_setzero_ps();
+ for (; i + 32 <= n; i += 32) {
+ c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+ c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+ c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+ c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+ }
+ __m128 g;
+ c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+ _mm256_add_ps(c2, c4));
+ g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+ _mm256_castps256_ps128(c1));
+ g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+ g = _mm_add_ss(g, _mm_movehdup_ps(g));
+ sumf += (ggml_float)_mm_cvtss_f32(g);
+
+#undef LOAD
+#endif
+
+ for (; i < n; ++i) {
+ sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
+ GGML_BF16_TO_FP32(y[i]));
+ }
+ *s = sumf;
+}
+
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+ sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ GGML_F16_VEC_REDUCE(sumf, sum);
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+ }
+#endif
+
+ *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
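+// hypothetical usage sketch (names are illustrative only): for a row-major f16
+// matrix m with rows of row_bytes bytes and float s[GGML_VEC_DOT_UNROLL],
+//   ggml_vec_dot_f16_unroll(n, row_bytes, s, (char *) m + r0*row_bytes, y);
+// yields s[k] = dot(row r0 + k of m, y) for k = 0 .. GGML_VEC_DOT_UNROLL-1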
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+ }
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+ ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+ sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+ }
+ }
+ }
+
+ // reduce sum0..sum3 to sum0
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+ GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+ }
+ }
+#else
+ for (int i = 0; i < n; ++i) {
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+ }
+ }
+#endif
+
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+ s[i] = sumf[i];
+ }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+#endif
+}
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ax[GGML_F16_ARR];
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+ }
+#endif
+}
+
+// xs and vs are byte strides of x and v
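+// i.e. y[i] += sum_k x_k[i]*v_k[0] over the GGML_VEC_MAD_UNROLL rows; each scale
+// factor is read from the first element of the corresponding v row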
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+ const float * restrict x[GGML_VEC_MAD_UNROLL];
+ const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+ for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+ x[i] = (const float *) ((const char *) xv + i*xs);
+ v[i] = (const float *) ((const char *) vv + i*vs);
+ }
+
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }
+
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#else
+ // scalar
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
+#if defined(GGML_USE_ACCELERATE)
+ vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F32_STEP - 1));
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] *= v;
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] *= v;
+ }
+#endif
+}
+
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+ const int np = (n & ~(GGML_F16_STEP - 1));
+
+ GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+ GGML_F16_VEC ay[GGML_F16_ARR];
+
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
+ for (int j = 0; j < GGML_F16_ARR; j++) {
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+ ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+ GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+#else
+ // scalar
+ for (int i = 0; i < n; ++i) {
+ y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+ }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
+inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
+inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
+// TODO: optimize performance
+inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+
+static const float GELU_COEF_A = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
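+
+// GELU_COEF_A and SQRT_2_OVER_PI implement the tanh approximation
+//   gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
+// while GELU_QUICK_COEF gives the sigmoid approximation x*sigmoid(1.702*x)
+// used by ggml_gelu_quick_f32 below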
+
+inline static float ggml_gelu_f32(float x) {
+ return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+ const uint16_t * i16 = (const uint16_t *) x;
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_table_gelu_f16[i16[i]];
+ }
+}
+
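+// with GGML_GELU_FP16, f32 GELU is approximated by rounding x to f16 and indexing
+// ggml_table_gelu_f16 with the raw 16-bit pattern (the table covers every f16 value);
+// inputs with |x| >= 10 skip the table, where gelu(x) is effectively 0 or x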
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+ uint16_t t;
+ for (int i = 0; i < n; ++i) {
+ if (x[i] <= -10.0f) {
+ y[i] = 0.0f;
+ } else if (x[i] >= 10.0f) {
+ y[i] = x[i];
+ } else {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+ }
+ }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_gelu_f32(x[i]);
+ }
+}
+#endif
+
+inline static float ggml_gelu_quick_f32(float x) {
+ return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+// const uint16_t * i16 = (const uint16_t *) x;
+// for (int i = 0; i < n; ++i) {
+// y[i] = ggml_table_gelu_quick_f16[i16[i]];
+// }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+ uint16_t t;
+ for (int i = 0; i < n; ++i) {
+ ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+ memcpy(&t, &fp16, sizeof(uint16_t));
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
+ }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+ for (int i = 0; i < n; ++i) {
+ y[i] = ggml_gelu_quick_f32(x[i]);
+ }
+}
+#endif
+
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+ return x/(1.0f + expf(-x));
+}
+
+#if __FINITE_MATH_ONLY__
+#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
+#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#endif
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
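+// rough idea of the routine: write x = n*ln(2) + r with n obtained via the
+// 0x1.8p23 round-to-nearest trick, build 2^n by adding n to the float exponent bits,
+// evaluate a short polynomial for e^r, and fix up lanes where |n| is large enough
+// to overflow the exponent field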
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+ const float32x4_t n = vsubq_f32(z, r);
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+ vdupq_n_f32(0x1.7f7d1cp-20f));
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+ const float32x4_t u = vmulq_f32(b, b);
+ const float32x4_t j = vfmaq_f32(
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+ return vfmaq_f32(k, j, k);
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t zero = vdupq_n_f32(0.0f);
+ const float32x4_t neg_x = vsubq_f32(zero, x);
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+ return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+ const __m512 n = _mm512_sub_ps(z, r);
+ const __m512 b =
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+ const __mmask16 d =
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+ const __m512 u = _mm512_mul_ps(b, b);
+ const __m512 j = _mm512_fmadd_ps(
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+ _mm512_set1_ps(0x1.573e2ep-5f)),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+ const __m512 res = _mm512_scalef_ps(j, n);
+ if (_mm512_kortestz(d, d))
+ return res;
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 alt = _mm512_mask_blend_ps(
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+ return _mm512_mask_blend_ps(d, res, alt);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1);
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+ const __m256 n = _mm256_sub_ps(z, r);
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+ const __m256 k = _mm256_castsi256_ps(
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+ const __m256i c = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(126), _CMP_GT_OQ));
+ const __m256 u = _mm256_mul_ps(b, b);
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+ return _mm256_fmadd_ps(j, k, k);
+ const __m256i g = _mm256_and_si256(
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+ _mm256_set1_epi32(0x82000000u));
+ const __m256 s1 =
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+ const __m256i d = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(192), _CMP_GT_OQ));
+ return _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+ _mm256_andnot_ps(
+ _mm256_castsi256_ps(d),
+ _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(c),
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1);
+ const __m256 zero = _mm256_setzero_ps();
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
+#else
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+ const __m128 n = _mm_sub_ps(z, r);
+ const __m128 b =
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+ const __m128i c =
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+ const __m128 u = _mm_mul_ps(b, b);
+ const __m128 j =
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+ if (!_mm_movemask_epi8(c))
+ return MADD128(j, k, k);
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+ _mm_set1_epi32(0x82000000u));
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+ const __m128i d =
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+ return _mm_or_ps(
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+ _mm_andnot_ps(_mm_castsi128_ps(d),
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+ const __m128 one = _mm_set1_ps(1);
+ const __m128 zero = _mm_setzero_ps();
+ const __m128 neg_x = _mm_sub_ps(zero, x);
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+ return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
+static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+ int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+ }
+#endif
+ for (; i < n; ++i) {
+ y[i] = ggml_silu_f32(x[i]);
+ }
+}
+
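+// computes y[i] = expf(x[i] - max) and returns the sum of the results; subtracting
+// the row maximum keeps expf() in range, and the caller divides by the returned sum
+// to finish the softmax normalization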
+static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+ int i = 0;
+ ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+ for (; i + 15 < n; i += 16) {
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+ _mm512_set1_ps(max)));
+ _mm512_storeu_ps(y + i, val);
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
+ }
+#elif defined(__AVX2__) && defined(__FMA__)
+ for (; i + 7 < n; i += 8) {
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+ _mm256_set1_ps(max)));
+ _mm256_storeu_ps(y + i, val);
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+ _mm256_castps256_ps128(val));
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+ sum += (ggml_float)_mm_cvtss_f32(val2);
+ }
+#elif defined(__SSE2__)
+ for (; i + 3 < n; i += 4) {
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+ _mm_set1_ps(max)));
+ _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+ val = _mm_add_ps(val, tmp);
+ tmp = _mm_movehl_ps(tmp, val);
+ val = _mm_add_ss(val, tmp);
+#endif
+ sum += (ggml_float)_mm_cvtss_f32(val);
+ }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+ for (; i + 3 < n; i += 4) {
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+ vdupq_n_f32(max)));
+ vst1q_f32(y + i, val);
+ sum += (ggml_float)vaddvq_f32(val);
+ }
+#endif
+ for (; i < n; ++i) {
+ float val = expf(x[i] - max);
+ sum += (ggml_float)val;
+ y[i] = val;
+ }
+ return sum;
+}
+
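+// with s = sigmoid(x), d/dx [x*s] = s + x*s*(1 - s) = s*(1 + x*(1 - s)); the backward
+// pass below scales this by the incoming gradient dy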
+inline static float ggml_silu_backward_f32(float x, float dy) {
+ const float s = 1.0f/(1.0f + expf(-x));
+ return dy*s*(1.0f + x*(1.0f - s));
+}
+
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+ for (int i = 0; i < n; ++i) {
+ dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
+ }
+}
+
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+ ggml_float sum = 0.0;
+ for (int i = 0; i < n; ++i) {
+ sum += (ggml_float)x[i];
+ }
+ *s = sum;
+#else
+ vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
+ ggml_float sum = 0.0;
+ for (int i = 0; i < n; ++i) {
+ sum += (ggml_float)x[i];
+ }
+ *s = sum;
+}
+
+inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
+ float sum = 0.0f;
+ for (int i = 0; i < n; ++i) {
+ sum += GGML_FP16_TO_FP32(x[i]);
+ }
+ *s = sum;
+}
+
+inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
+ float sum = 0.0f;
+ for (int i = 0; i < n; ++i) {
+ sum += GGML_BF16_TO_FP32(x[i]);
+ }
+ *s = sum;
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+ float max = -INFINITY;
+ for (int i = 0; i < n; ++i) {
+ max = MAX(max, x[i]);
+ }
+ *s = max;
+#else
+ vDSP_maxv(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+ ggml_vec_norm_f32(n, s, x);
+ *s = 1.f/(*s);
+}
+
+inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+ float max = -INFINITY;
+ int idx = 0;
+ for (int i = 0; i < n; ++i) {
+ max = MAX(max, x[i]);
+ if (max == x[i]) { idx = i; }
+ }
+ *s = idx;
+}
+
+//
+// data types
+//
+
+static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+ "NONE",
+
+ "DUP",
+ "ADD",
+ "ADD1",
+ "ACC",
+ "SUB",
+ "MUL",
+ "DIV",
+ "SQR",
+ "SQRT",
+ "LOG",
+ "SUM",
+ "SUM_ROWS",
+ "MEAN",
+ "ARGMAX",
+ "REPEAT",
+ "REPEAT_BACK",
+ "CONCAT",
+ "SILU_BACK",
+ "NORM",
+ "RMS_NORM",
+ "RMS_NORM_BACK",
+ "GROUP_NORM",
+
+ "MUL_MAT",
+ "MUL_MAT_ID",
+ "OUT_PROD",
+
+ "SCALE",
+ "SET",
+ "CPY",
+ "CONT",
+ "RESHAPE",
+ "VIEW",
+ "PERMUTE",
+ "TRANSPOSE",
+ "GET_ROWS",
+ "GET_ROWS_BACK",
+ "DIAG",
+ "DIAG_MASK_INF",
+ "DIAG_MASK_ZERO",
+ "SOFT_MAX",
+ "SOFT_MAX_BACK",
+ "ROPE",
+ "ROPE_BACK",
+ "CLAMP",
+ "CONV_TRANSPOSE_1D",
+ "IM2COL",
+ "CONV_TRANSPOSE_2D",
+ "POOL_1D",
+ "POOL_2D",
+ "UPSCALE",
+ "PAD",
+ "ARANGE",
+ "TIMESTEP_EMBEDDING",
+ "ARGSORT",
+ "LEAKY_RELU",
+
+ "FLASH_ATTN_EXT",
+ "FLASH_ATTN_BACK",
+ "SSM_CONV",
+ "SSM_SCAN",
+ "WIN_PART",
+ "WIN_UNPART",
+ "GET_REL_POS",
+ "ADD_REL_POS",
+
+ "UNARY",
+
+ "MAP_UNARY",
+ "MAP_BINARY",
+
+ "MAP_CUSTOM1_F32",
+ "MAP_CUSTOM2_F32",
+ "MAP_CUSTOM3_F32",
+
+ "MAP_CUSTOM1",
+ "MAP_CUSTOM2",
+ "MAP_CUSTOM3",
+
+ "CROSS_ENTROPY_LOSS",
+ "CROSS_ENTROPY_LOSS_BACK",
+};
+
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+
+static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+ "none",
+
+ "x",
+ "x+y",
+ "x+y",
+ "view(x,nb,offset)+=y->x",
+ "x-y",
+ "x*y",
+ "x/y",
+ "x^2",
+ "√x",
+ "log(x)",
+ "Σx",
+ "Σx_k",
+ "Σx/n",
+ "argmax(x)",
+ "repeat(x)",
+ "repeat_back(x)",
+ "concat(x, y)",
+ "silu_back(x)",
+ "norm(x)",
+ "rms_norm(x)",
+ "rms_norm_back(x)",
+ "group_norm(x)",
+
+ "X*Y",
+ "X[i]*Y",
+ "X*Y",
+
+ "x*v",
+ "y-\\>view(x)",
+ "x-\\>y",
+ "cont(x)",
+ "reshape(x)",
+ "view(x)",
+ "permute(x)",
+ "transpose(x)",
+ "get_rows(x)",
+ "get_rows_back(x)",
+ "diag(x)",
+ "diag_mask_inf(x)",
+ "diag_mask_zero(x)",
+ "soft_max(x)",
+ "soft_max_back(x)",
+ "rope(x)",
+ "rope_back(x)",
+ "clamp(x)",
+ "conv_transpose_1d(x)",
+ "im2col(x)",
+ "conv_transpose_2d(x)",
+ "pool_1d(x)",
+ "pool_2d(x)",
+ "upscale(x)",
+ "pad(x)",
+ "arange(start, stop, step)",
+ "timestep_embedding(timesteps, dim, max_period)",
+ "argsort(x)",
+ "leaky_relu(x)",
+
+ "flash_attn_ext(x)",
+ "flash_attn_back(x)",
+ "ssm_conv(x)",
+ "ssm_scan(x)",
+ "win_part(x)",
+ "win_unpart(x)",
+ "get_rel_pos(x)",
+ "add_rel_pos(x)",
+
+ "unary(x)",
+
+ "f(x)",
+ "f(x,y)",
+
+ "custom_f32(x)",
+ "custom_f32(x,y)",
+ "custom_f32(x,y,z)",
+
+ "custom(x)",
+ "custom(x,y)",
+ "custom(x,y,z)",
+
+ "cross_entropy_loss(x,y)",
+ "cross_entropy_loss_back(x,y)",
+};
+
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+
+static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
+
+
+static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+ "ABS",
+ "SGN",
+ "NEG",
+ "STEP",
+ "TANH",
+ "ELU",
+ "RELU",
+ "SIGMOID",
+ "GELU",
+ "GELU_QUICK",
+ "SILU",
+ "HARDSWISH",
+ "HARDSIGMOID",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
+
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+//
+// NUMA support
+//
+
+#define GGML_NUMA_MAX_NODES 8
+#define GGML_NUMA_MAX_CPUS 512
+
+struct ggml_numa_node {
+ uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
+ uint32_t n_cpus;
+};
+
+struct ggml_numa_nodes {
+ enum ggml_numa_strategy numa_strategy;
+ struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
+ uint32_t n_nodes;
+ uint32_t total_cpus; // hardware threads on system
+    uint32_t current_node;   // node on which main process is executing
+#if defined(__gnu_linux__)
+ cpu_set_t cpuset; // cpuset from numactl
+#else
+ uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
+#endif
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+ struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+ struct ggml_numa_nodes numa;
+};
+
+// global state
+static struct ggml_state g_state;
+static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
+
+// critical section via spin lock
+inline static void ggml_critical_section_start(void) {
+ while (atomic_flag_test_and_set(&g_state_critical)) {
+ // spin
+ sched_yield();
+ }
+}
+
+#ifdef GGML_USE_OPENMP
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+ if (shared->n_threads == 1) {
+ return;
+ }
+
+ #pragma omp barrier
+}
+#else
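+// counting barrier: each thread increments n_barrier on arrival; the last arrival
+// resets the counter and bumps n_barrier_passed, which the waiting threads poll
+// (spinning with a pause/isb hint before falling back to sched_yield)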
+static void ggml_barrier(struct ggml_compute_state_shared * shared) {
+ if (shared->n_threads == 1) {
+ return;
+ }
+
+ atomic_int * n_barrier = &shared->n_barrier;
+ atomic_int * n_barrier_passed = &shared->n_barrier_passed;
+
+ int n_threads = shared->n_threads;
+ int passed_old = atomic_load(n_barrier_passed);
+
+ if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) {
+ // last thread
+ atomic_store(n_barrier, 0);
+ atomic_fetch_add(n_barrier_passed, 1);
+ } else {
+ // wait for other threads
+ const int n_spin_before_sleep = 100000;
+ while (true) {
+ for (int i = 0; i < n_spin_before_sleep; i++) {
+ if (atomic_load(n_barrier_passed) != passed_old) {
+ return;
+ }
+ #if defined(__SSE3__)
+ _mm_pause();
+ #elif defined __ARM_NEON
+ __asm__ __volatile__("isb\n");
+ #endif
+ }
+ sched_yield();
+ }
+ }
+}
+#endif
+
+// TODO: make this somehow automatically executed
+// some sort of "sentry" mechanism
+inline static void ggml_critical_section_end(void) {
+ atomic_flag_clear(&g_state_critical);
+}
+
+#if defined(__gnu_linux__)
+static cpu_set_t ggml_get_numa_affinity(void) {
+ cpu_set_t cpuset;
+ pthread_t thread;
+ thread = pthread_self();
+ CPU_ZERO(&cpuset);
+ pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
+ return cpuset;
+}
+#else
+static uint32_t ggml_get_numa_affinity(void) {
+ return 0; // no NUMA support
+}
+#endif
+
+void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
+ if (g_state.numa.n_nodes > 0) {
+ fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+ return;
+ }
+
+#if defined(__gnu_linux__)
+ struct stat st;
+ char path[256];
+ int rv;
+
+ // set numa scheme
+ g_state.numa.numa_strategy = numa_flag;
+
+ GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
+
+ g_state.numa.cpuset = ggml_get_numa_affinity();
+
+ // enumerate nodes
+ while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
+ rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
+ GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+ if (stat(path, &st) != 0) { break; }
+ ++g_state.numa.n_nodes;
+ }
+
+ // enumerate CPUs
+ while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
+ rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
+ GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+ if (stat(path, &st) != 0) { break; }
+ ++g_state.numa.total_cpus;
+ }
+
+ GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
+ // figure out which node we're on
+ uint current_cpu;
+ int getcpu_ret = 0;
+#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__)
+ getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
+#else
+ // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
+# if !defined(SYS_getcpu) && defined(SYS_get_cpu)
+# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
+# endif
+ getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
+#endif
+
+ if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
+ g_state.numa.n_nodes = 0;
+ return;
+ }
+
+ GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
+
+ for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
+ struct ggml_numa_node * node = &g_state.numa.nodes[n];
+ GGML_PRINT_DEBUG("CPUs on node %u:", n);
+ node->n_cpus = 0;
+ for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
+ rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
+ GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+ if (stat(path, &st) == 0) {
+ node->cpus[node->n_cpus++] = c;
+ GGML_PRINT_DEBUG(" %u", c);
+ }
+ }
+ GGML_PRINT_DEBUG("\n");
+ }
+
+ if (ggml_is_numa()) {
+ FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
+ if (fptr != NULL) {
+ char buf[42];
+ if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
+ GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
+ }
+ fclose(fptr);
+ }
+ }
+#else
+ UNUSED(numa_flag);
+ // TODO
+#endif
+}
+
+bool ggml_is_numa(void) {
+ return g_state.numa.n_nodes > 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+ GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
+ obj->type, obj->offs, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+ struct ggml_object * obj = ctx->objects_begin;
+
+ GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+ while (obj != NULL) {
+ ggml_print_object(obj);
+ obj = obj->next;
+ }
+
+ GGML_PRINT("%s: --- end ---\n", __func__);
+}
+
+GGML_CALL int64_t ggml_nelements(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+GGML_CALL int64_t ggml_nrows(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
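+// total size of the tensor data in bytes, computed from the strides so that padded
+// or permuted views are handled; the quantized branch counts the first dimension
+// in whole blocks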
+GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+ size_t nbytes;
+ size_t blck_size = ggml_blck_size(tensor->type);
+ if (blck_size == 1) {
+ nbytes = ggml_type_size(tensor->type);
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
+ }
+ else {
+ nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+ }
+ }
+
+ return nbytes;
+}
+
+size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
+ return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
+}
+
+GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
+ return type_traits[type].blck_size;
+}
+
+GGML_CALL size_t ggml_type_size(enum ggml_type type) {
+ return type_traits[type].type_size;
+}
+
+GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) {
+ assert(ne % ggml_blck_size(type) == 0);
+ return ggml_type_size(type)*ne/ggml_blck_size(type);
+}
+
+double ggml_type_sizef(enum ggml_type type) {
+ return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
+}
+
+GGML_CALL const char * ggml_type_name(enum ggml_type type) {
+ return type_traits[type].type_name;
+}
+
+GGML_CALL bool ggml_is_quantized(enum ggml_type type) {
+ return type_traits[type].is_quantized;
+}
+
+GGML_CALL const char * ggml_op_name(enum ggml_op op) {
+ return GGML_OP_NAME[op];
+}
+
+const char * ggml_op_symbol(enum ggml_op op) {
+ return GGML_OP_SYMBOL[op];
+}
+
+const char * ggml_unary_op_name(enum ggml_unary_op op) {
+ return GGML_UNARY_OP_NAME[op];
+}
+
+GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) {
+ if (t->op == GGML_OP_UNARY) {
+ enum ggml_unary_op uop = ggml_get_unary_op(t);
+ return ggml_unary_op_name(uop);
+ }
+ return ggml_op_name(t->op);
+}
+
+GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor) {
+ return ggml_type_size(tensor->type);
+}
+
+bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_vector(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_3d(const struct ggml_tensor * tensor) {
+ return tensor->ne[3] == 1;
+}
+
+int ggml_n_dims(const struct ggml_tensor * tensor) {
+ for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
+ if (tensor->ne[i] > 1) {
+ return i + 1;
+ }
+ }
+ return 1;
+}
+
+static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return (t0->ne[0] == t1->ne[0]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
+}
+
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return (t0->ne[1] == t1->ne[1]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
+}
+
+enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
+ enum ggml_type wtype = GGML_TYPE_COUNT;
+
+ switch (ftype) {
+ case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
+ case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
+ case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break;
+ case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
+ case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
+ case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
+ case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
+ case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+ case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
+ case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
+ case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
+ case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
+ case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+ case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
+ case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break;
+ case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break;
+ case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break;
+ case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break;
+ case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break;
+ case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break;
+ case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break;
+ case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
+ case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
+ case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
+ case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
+ case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break;
+ case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break;
+ case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
+ case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
+ }
+
+ GGML_ASSERT(wtype != GGML_TYPE_COUNT);
+
+ return wtype;
+}
+
+size_t ggml_tensor_overhead(void) {
+ return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
+}
+
+GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+ return tensor->nb[0] > tensor->nb[1];
+}
+
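+// true if the tensor data is laid out contiguously in memory, while allowing dimensions 1..n
+// to have arbitrary strides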
+static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
+ size_t next_nb = ggml_type_size(tensor->type);
+ if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
+ return false;
+ }
+ next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ if (tensor->ne[i] != 1) {
+ if (i > n) {
+ if (tensor->nb[i] != next_nb) {
+ return false;
+ }
+ next_nb *= tensor->ne[i];
+ } else {
+ // this dimension does not need to be contiguous
+ next_nb = tensor->ne[i]*tensor->nb[i];
+ }
+ }
+ }
+ return true;
+}
+
+GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous_0(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous_n(tensor, 0);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous_n(tensor, 1);
+}
+
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous_n(tensor, 2);
+}
+
+GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
+static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) {
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ if (tensor->ne[i] == 0) {
+ // empty if any dimension has no elements
+ return true;
+ }
+ }
+ return false;
+}
+
+bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->ne[0] == t1->ne[0]) &&
+ (t0->ne[1] == t1->ne[1]) &&
+ (t0->ne[2] == t1->ne[2]) &&
+ (t0->ne[3] == t1->ne[3]);
+}
+
+bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->nb[0] == t1->nb[0]) &&
+ (t0->nb[1] == t1->nb[1]) &&
+ (t0->nb[2] == t1->nb[2]) &&
+ (t0->nb[3] == t1->nb[3]);
+}
+
+// check if t1 can be represented as a repetition of t0
+bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return ggml_is_empty(t0) ? ggml_is_empty(t1) :
+ (t1->ne[0]%t0->ne[0] == 0) &&
+ (t1->ne[1]%t0->ne[1] == 0) &&
+ (t1->ne[2]%t0->ne[2] == 0) &&
+ (t1->ne[3]%t0->ne[3] == 0);
+}
+
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
+static inline int ggml_up32(int n) {
+ return (n + 31) & ~31;
+}
+
+//static inline int ggml_up64(int n) {
+// return (n + 63) & ~63;
+//}
+
+static inline int ggml_up(int n, int m) {
+ // assert m is a power of 2
+ GGML_ASSERT((m & (m - 1)) == 0);
+ return (n + m - 1) & ~(m - 1);
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define ggml_assert_aligned(ptr) \
+ GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+ // make this function thread safe
+ ggml_critical_section_start();
+
+ static bool is_first_call = true;
+
+ if (is_first_call) {
+ // initialize time system (required on Windows)
+ ggml_time_init();
+
+        // initialize the f16 -> f32 conversion table and the GELU / Quick GELU f16 tables
+ {
+ const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+ for (int i = 0; i < (1 << 16); ++i) {
+ union {
+ uint16_t u16;
+ ggml_fp16_t fp16;
+ } u = {i};
+ float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
+ ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+ ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
+ }
+
+ const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+ GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+ }
+
+ // initialize g_state
+ {
+ const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
+ g_state = (struct ggml_state) {
+ /*.contexts =*/ { { 0 } },
+ /*.numa =*/ {
+ .n_nodes = 0,
+ .total_cpus = 0,
+ },
+ };
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
+ g_state.contexts[i].used = false;
+ }
+
+ const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+ GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+ }
+
+ is_first_call = false;
+ }
+
+ // find non-used context in g_state
+ struct ggml_context * ctx = NULL;
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (!g_state.contexts[i].used) {
+ g_state.contexts[i].used = true;
+ ctx = &g_state.contexts[i].context;
+
+ GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
+ break;
+ }
+ }
+
+ if (ctx == NULL) {
+ GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
+
+ ggml_critical_section_end();
+
+ return NULL;
+ }
+
+    // allow calling ggml_init with a zero size
+ if (params.mem_size == 0) {
+ params.mem_size = GGML_MEM_ALIGN;
+ }
+
+ const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
+
+ *ctx = (struct ggml_context) {
+ /*.mem_size =*/ mem_size,
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
+ /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
+ /*.no_alloc =*/ params.no_alloc,
+ /*.no_alloc_save =*/ params.no_alloc,
+ /*.n_objects =*/ 0,
+ /*.objects_begin =*/ NULL,
+ /*.objects_end =*/ NULL,
+ /*.scratch =*/ { 0, 0, NULL, },
+ /*.scratch_save =*/ { 0, 0, NULL, },
+ };
+
+ GGML_ASSERT(ctx->mem_buffer != NULL);
+
+ ggml_assert_aligned(ctx->mem_buffer);
+
+ GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+ ggml_critical_section_end();
+
+ return ctx;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+ if (ctx == NULL) {
+ return;
+ }
+
+ // make this function thread safe
+ ggml_critical_section_start();
+
+ bool found = false;
+
+ for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
+ if (&g_state.contexts[i].context == ctx) {
+ g_state.contexts[i].used = false;
+
+ GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
+ __func__, i, ggml_used_mem(ctx));
+
+ if (ctx->mem_buffer_owned) {
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
+ }
+
+ found = true;
+ break;
+ }
+ }
+
+ if (!found) {
+ GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+ }
+
+ ggml_critical_section_end();
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+ return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
+}
+
+size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
+ const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0;
+
+ ctx->scratch = scratch;
+
+ return result;
+}
+
+bool ggml_get_no_alloc(struct ggml_context * ctx) {
+ return ctx->no_alloc;
+}
+
+void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
+ ctx->no_alloc = no_alloc;
+}
+
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
+ return ctx->mem_buffer;
+}
+
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
+ return ctx->mem_size;
+}
+
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+ size_t max_size = 0;
+
+ for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
+ size_t bytes = ggml_nbytes(tensor);
+ max_size = MAX(max_size, bytes);
+ }
+
+ return max_size;
+}
+
+// IMPORTANT:
+// when creating "opt" tensors, always save and load the scratch buffer
+// this is an error-prone process, but it is necessary to support inplace
+// operators when using scratch buffers
+// TODO: implement a better way
+static void ggml_scratch_save(struct ggml_context * ctx) {
+ // this is needed to allow opt tensors to store their data
+ // TODO: again, need to find a better way
+ ctx->no_alloc_save = ctx->no_alloc;
+ ctx->no_alloc = false;
+
+ ctx->scratch_save = ctx->scratch;
+ ctx->scratch.data = NULL;
+}
+
+static void ggml_scratch_load(struct ggml_context * ctx) {
+ ctx->no_alloc = ctx->no_alloc_save;
+
+ ctx->scratch = ctx->scratch_save;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
+ // always insert objects at the end of the context's memory pool
+ struct ggml_object * obj_cur = ctx->objects_end;
+
+ const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
+ const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+ const size_t cur_end = cur_offs + cur_size;
+
+ // align to GGML_MEM_ALIGN
+ size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
+
+ char * const mem_buffer = ctx->mem_buffer;
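+    // the object header is placed at the current end of the pool and its payload
+    // (size_needed bytes) follows immediately after it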
+ struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+ if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+ GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
+ assert(false);
+ return NULL;
+ }
+
+ *obj_new = (struct ggml_object) {
+ .offs = cur_end + GGML_OBJECT_SIZE,
+ .size = size_needed,
+ .next = NULL,
+ .type = type,
+ };
+
+ ggml_assert_aligned(mem_buffer + obj_new->offs);
+
+ if (obj_cur != NULL) {
+ obj_cur->next = obj_new;
+ } else {
+ // this is the first object in this context
+ ctx->objects_begin = obj_new;
+ }
+
+ ctx->objects_end = obj_new;
+
+ //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+
+ return obj_new;
+}
+
+static struct ggml_tensor * ggml_new_tensor_impl(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int64_t * ne,
+ struct ggml_tensor * view_src,
+ size_t view_offs) {
+
+ assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
+
+ // find the base tensor and absolute offset
+ if (view_src != NULL && view_src->view_src != NULL) {
+ view_offs += view_src->view_offs;
+ view_src = view_src->view_src;
+ }
+
+ size_t data_size = ggml_row_size(type, ne[0]);
+ for (int i = 1; i < n_dims; i++) {
+ data_size *= ne[i];
+ }
+
+ GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
+
+ void * data = view_src != NULL ? view_src->data : NULL;
+ if (data != NULL) {
+ data = (char *) data + view_offs;
+ }
+
+ size_t obj_alloc_size = 0;
+
+ if (view_src == NULL && !ctx->no_alloc) {
+ if (ctx->scratch.data != NULL) {
+ // allocate tensor data in the scratch buffer
+ if (ctx->scratch.offs + data_size > ctx->scratch.size) {
+ GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+ __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
+ assert(false);
+ return NULL;
+ }
+
+ data = (char * const) ctx->scratch.data + ctx->scratch.offs;
+
+ ctx->scratch.offs += data_size;
+ } else {
+ // allocate tensor data in the context's memory pool
+ obj_alloc_size = data_size;
+ }
+ }
+
+ struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
+
+ // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
+
+ struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
+
+#ifdef __clang__
+ // temporary until ggml_tensor::backend is removed
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+ *result = (struct ggml_tensor) {
+ /*.type =*/ type,
+ /*.backend =*/ GGML_BACKEND_TYPE_CPU,
+ /*.buffer =*/ NULL,
+ /*.ne =*/ { 1, 1, 1, 1 },
+ /*.nb =*/ { 0, 0, 0, 0 },
+ /*.op =*/ GGML_OP_NONE,
+ /*.op_params =*/ { 0 },
+ /*.flags =*/ 0,
+ /*.grad =*/ NULL,
+ /*.src =*/ { NULL },
+ /*.view_src =*/ view_src,
+ /*.view_offs =*/ view_offs,
+ /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
+ /*.name =*/ { 0 },
+ /*.extra =*/ NULL,
+ ///*.padding =*/ { 0 },
+ };
+
+#ifdef __clang__
+ #pragma clang diagnostic pop
+#endif
+
+ // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+ //ggml_assert_aligned(result->data);
+
+ for (int i = 0; i < n_dims; i++) {
+ result->ne[i] = ne[i];
+ }
+
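+    // nb[0] is the byte size of one element (one block for quantized types);
+    // nb[1] therefore spans ne[0]/blck_size blocks, i.e. one full row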
+ result->nb[0] = ggml_type_size(type);
+ result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
+ result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+ }
+
+ ctx->n_objects++;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int n_dims,
+ const int64_t * ne) {
+ return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0) {
+ return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1) {
+ const int64_t ne[2] = { ne0, ne1 };
+ return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
+ const int64_t ne[3] = { ne0, ne1, ne2 };
+ return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+ struct ggml_context * ctx,
+ enum ggml_type type,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3) {
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+ return ggml_new_tensor(ctx, type, 4, ne);
+}
+
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+ ggml_scratch_save(ctx);
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+ ggml_scratch_load(ctx);
+
+ ggml_set_i32(result, value);
+
+ return result;
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+ ggml_scratch_save(ctx);
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+ ggml_scratch_load(ctx);
+
+ ggml_set_f32(result, value);
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+ return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);
+}
+
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+ GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
+ assert(params_size <= GGML_MAX_OP_PARAMS);
+ memcpy(tensor->op_params, params, params_size);
+}
+
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+ return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+ return ((const float *)(tensor->op_params))[i];
+}
+
+static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+ ((int32_t *)(tensor->op_params))[i] = value;
+}
+
+static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+ ((float *)(tensor->op_params))[i] = value;
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+ memset(tensor->data, 0, ggml_nbytes(tensor));
+ return tensor;
+}
+
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+ const int n = ggml_nrows(tensor);
+ const int nc = tensor->ne[0];
+ const size_t n1 = tensor->nb[1];
+
+ char * const data = tensor->data;
+
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+ }
+ } break;
+ case GGML_TYPE_BF16:
+ {
+                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+ }
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+ const int n = ggml_nrows(tensor);
+ const int nc = tensor->ne[0];
+ const size_t n1 = tensor->nb[1];
+
+ char * const data = tensor->data;
+
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ assert(tensor->nb[0] == sizeof(int8_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I16:
+ {
+ assert(tensor->nb[0] == sizeof(int16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_I32:
+ {
+ assert(tensor->nb[0] == sizeof(int32_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+ }
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+ }
+ } break;
+ case GGML_TYPE_F32:
+ {
+ assert(tensor->nb[0] == sizeof(float));
+ for (int i = 0; i < n; i++) {
+ ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+ }
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ return tensor;
+}
+
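+// decompose a flat element index i into (i0, i1, i2, i3) coordinates, with dimension 0 varying fastest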
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+ const int64_t ne2 = tensor->ne[2];
+ const int64_t ne1 = tensor->ne[1];
+ const int64_t ne0 = tensor->ne[0];
+
+ const int64_t i3_ = (i/(ne2*ne1*ne0));
+ const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+ const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+ const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+ if (i0) {
+ * i0 = i0_;
+ }
+ if (i1) {
+ * i1 = i1_;
+ }
+ if (i2) {
+ * i2 = i2_;
+ }
+ if (i3) {
+ * i3 = i3_;
+ }
+}
+
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ return ((int8_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ return ((int16_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ return ((int32_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+ }
+ case GGML_TYPE_BF16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+ }
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ return ((float *)(tensor->data))[i];
+ }
+ default:
+ {
+ GGML_ASSERT(false);
+ }
+ }
+
+    return 0;
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+ ((int8_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+ ((int16_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+ ((int32_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+ ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(tensor->nb[0] == sizeof(float));
+ ((float *)(tensor->data))[i] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ return ((int8_t *) data)[0];
+ case GGML_TYPE_I16:
+ return ((int16_t *) data)[0];
+ case GGML_TYPE_I32:
+ return ((int32_t *) data)[0];
+ case GGML_TYPE_F16:
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ case GGML_TYPE_BF16:
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+ case GGML_TYPE_F32:
+ return ((float *) data)[0];
+ default:
+ GGML_ASSERT(false);
+ }
+
+    return 0;
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+ }
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ return ((int8_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_I16:
+ {
+ return ((int16_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_I32:
+ {
+ return ((int32_t *)(tensor->data))[i];
+ }
+ case GGML_TYPE_F16:
+ {
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+ }
+ case GGML_TYPE_BF16:
+ {
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+ }
+ case GGML_TYPE_F32:
+ {
+ return ((float *)(tensor->data))[i];
+ }
+ default:
+ {
+ GGML_ASSERT(false);
+ }
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+ if (!ggml_is_contiguous(tensor)) {
+ int64_t id[4] = { 0, 0, 0, 0 };
+ ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+ ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+ return;
+ }
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(tensor->data))[i] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(tensor->data))[i] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ return ((int8_t *) data)[0];
+ case GGML_TYPE_I16:
+ return ((int16_t *) data)[0];
+ case GGML_TYPE_I32:
+ return ((int32_t *) data)[0];
+ case GGML_TYPE_F16:
+ return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+ case GGML_TYPE_BF16:
+ return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+ case GGML_TYPE_F32:
+ return ((float *) data)[0];
+ default:
+ GGML_ASSERT(false);
+ }
+
+ return 0.0f;
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+ void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+ switch (tensor->type) {
+ case GGML_TYPE_I8:
+ {
+ ((int8_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I16:
+ {
+ ((int16_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_I32:
+ {
+ ((int32_t *)(data))[0] = value;
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ((float *)(data))[0] = value;
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+ return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+ assert(tensor->type == GGML_TYPE_F32);
+ return (float *)(tensor->data);
+}
+
+GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->op == GGML_OP_UNARY);
+ return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
+}
+
+const char * ggml_get_name(const struct ggml_tensor * tensor) {
+ return tensor->name;
+}
+
+struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+ strncpy(tensor->name, name, sizeof(tensor->name) - 1);
+ tensor->name[sizeof(tensor->name) - 1] = '\0';
+ return tensor;
+}
+
+struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
+ va_end(args);
+ return tensor;
+}
+
+struct ggml_tensor * ggml_view_tensor(
+ struct ggml_context * ctx,
+ struct ggml_tensor * src) {
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);
+ ggml_format_name(result, "%s (view)", src->name);
+
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ result->nb[i] = src->nb[i];
+ }
+
+ return result;
+}
+
+struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {
+ struct ggml_object * obj = ctx->objects_begin;
+
+ char * const mem_buffer = ctx->mem_buffer;
+
+ while (obj != NULL) {
+ if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
+ }
+
+ obj = obj->next;
+ }
+
+ return NULL;
+}
+
+struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {
+ struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
+ obj = obj->next;
+
+ char * const mem_buffer = ctx->mem_buffer;
+
+ while (obj != NULL) {
+ if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+ return (struct ggml_tensor *)(mem_buffer + obj->offs);
+ }
+
+ obj = obj->next;
+ }
+
+ return NULL;
+}
+
+struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
+ struct ggml_object * obj = ctx->objects_begin;
+
+ char * const mem_buffer = ctx->mem_buffer;
+
+ while (obj != NULL) {
+ if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+ struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
+ if (strcmp(cur->name, name) == 0) {
+ return cur;
+ }
+ }
+
+ obj = obj->next;
+ }
+
+ return NULL;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_dup
+
+static struct ggml_tensor * ggml_dup_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DUP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_dup(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+static struct ggml_tensor * ggml_add_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ GGML_ASSERT(ggml_can_repeat(b, a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
+
+    // currently only supported for quantized input and f16/bf16
+ GGML_ASSERT(ggml_is_quantized(a->type) ||
+ a->type == GGML_TYPE_F16 ||
+ a->type == GGML_TYPE_BF16);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
+
+ result->op = GGML_OP_ADD;
+ result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ enum ggml_type type) {
+ return ggml_add_cast_impl(ctx, a, b, type);
+}
+
+// ggml_add1
+
+static struct ggml_tensor * ggml_add1_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_scalar(b));
+ GGML_ASSERT(ggml_is_padded_1d(a));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD1;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add1(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add1_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add1_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_add1_impl(ctx, a, b, true);
+}
+
+// ggml_acc
+
+static struct ggml_tensor * ggml_acc_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset,
+ bool inplace) {
+ GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
+ GGML_ASSERT(b->type == GGML_TYPE_F32);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_ACC;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_acc(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset) {
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_acc_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset) {
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
+// ggml_sub
+
+static struct ggml_tensor * ggml_sub_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SUB;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sub(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+static struct ggml_tensor * ggml_mul_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ GGML_ASSERT(ggml_can_repeat(b, a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ is_node = true;
+ }
+
+ if (inplace) {
+ GGML_ASSERT(!is_node);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MUL;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_mul(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+static struct ggml_tensor * ggml_div_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ GGML_ASSERT(ggml_can_repeat(b, a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ if (inplace) {
+ GGML_ASSERT(!is_node);
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_DIV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_div(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+static struct ggml_tensor * ggml_sqr_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQR;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+static struct ggml_tensor * ggml_sqrt_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SQRT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_log
+
+static struct ggml_tensor * ggml_log_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_LOG;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_log(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_log_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_log_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_log_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+ result->op = GGML_OP_SUM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_sum_rows
+
+struct ggml_tensor * ggml_sum_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ int64_t ne[GGML_MAX_DIMS] = { 1 };
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+ ne[i] = a->ne[i];
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+ result->op = GGML_OP_SUM_ROWS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_MEAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_argmax
+
+struct ggml_tensor * ggml_argmax(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ GGML_ASSERT(ggml_is_matrix(a));
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false);
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
+
+ result->op = GGML_OP_ARGMAX;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_can_repeat(a, b));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
+
+ result->op = GGML_OP_REPEAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_can_repeat(b, a));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ if (ggml_are_same_shape(a, b) && !is_node) {
+ return a;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
+
+ result->op = GGML_OP_REPEAT_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_concat
+
+struct ggml_tensor * ggml_concat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int dim) {
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+ int64_t ne[GGML_MAX_DIMS];
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+ if (d == dim) {
+ ne[d] = a->ne[d] + b->ne[d];
+ continue;
+ }
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
+ ne[d] = a->ne[d];
+ }
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+ ggml_set_op_params_i32(result, 0, dim);
+
+ result->op = GGML_OP_CONCAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
+}
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
+}
+
+// ggml_tanh
+
+struct ggml_tensor * ggml_tanh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
+}
+
+struct ggml_tensor * ggml_tanh_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
+}
+
+// ggml_elu
+
+struct ggml_tensor * ggml_elu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
+}
+
+struct ggml_tensor * ggml_elu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
+}
+
+// ggml_leaky_relu
+
+struct ggml_tensor * ggml_leaky_relu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a, float negative_slope, bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+ result->op = GGML_OP_LEAKY_RELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
+}
+
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
+}
+
+// ggml_silu
+
+struct ggml_tensor * ggml_silu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
+}
+
+struct ggml_tensor * ggml_silu_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
+}
+
+// ggml_silu_back
+
+struct ggml_tensor * ggml_silu_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SILU_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_hardswish
+
+struct ggml_tensor * ggml_hardswish(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);
+}
+
+// ggml_hardsigmoid
+
+struct ggml_tensor * ggml_hardsigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
+}
+
+// ggml_norm
+
+static struct ggml_tensor * ggml_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &eps, sizeof(eps));
+
+ result->op = GGML_OP_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_norm_impl(ctx, a, eps, true);
+}
+
+// ggml_rms_norm
+
+static struct ggml_tensor * ggml_rms_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &eps, sizeof(eps));
+
+ result->op = GGML_OP_RMS_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, true);
+}
+
+// ggml_rms_norm_back
+
+struct ggml_tensor * ggml_rms_norm_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps) {
+ bool is_node = false;
+
+ if (a->grad) {
+ // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &eps, sizeof(eps));
+
+ result->op = GGML_OP_RMS_NORM_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_group_norm
+
+static struct ggml_tensor * ggml_group_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_groups,
+ bool inplace) {
+
+ bool is_node = false;
+ if (!inplace && (a->grad)) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op_params[0] = n_groups;
+
+ result->op = GGML_OP_GROUP_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_group_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_groups) {
+ return ggml_group_norm_impl(ctx, a, n_groups, false);
+}
+
+struct ggml_tensor * ggml_group_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_groups) {
+ return ggml_group_norm_impl(ctx, a, n_groups, true);
+}
+
+// ggml_mul_mat
+
+struct ggml_tensor * ggml_mul_mat(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_can_mul_mat(a, b));
+ GGML_ASSERT(!ggml_is_transposed(a));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_MUL_MAT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+void ggml_mul_mat_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec) {
+ GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
+ const int32_t prec_i32 = (int32_t) prec;
+
+ ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
+// ggml_mul_mat_id
+
+/*
+ c = ggml_mul_mat_id(ctx, as, b, ids);
+
+ as -> [cols, rows, n_expert]
+    ids -> [n_expert_used, n_tokens] (i32)
+ b -> [cols, n_expert_used, n_tokens]
+ c -> [rows, n_expert_used, n_tokens]
+
+    in b, the n_expert_used dimension can be broadcast to match the n_expert_used of ids
+
+ c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
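+// a minimal usage sketch (shapes and dimension names below are hypothetical placeholders,
+// for illustration only):
+//
+//   struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_ff, n_expert);
+//   struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_expert_used, n_tokens);
+//   struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens);
+//   struct ggml_tensor * c   = ggml_mul_mat_id(ctx, as, b, ids); // -> [n_ff, n_expert_used, n_tokens]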
+struct ggml_tensor * ggml_mul_mat_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * as,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids) {
+ GGML_ASSERT(!ggml_is_transposed(as));
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+ GGML_ASSERT(b->ne[3] == 1); // b is 3d
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+ GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
+ GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+ GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
+
+ bool is_node = false;
+
+ if (as->grad || b->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_MUL_MAT_ID;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = as;
+ result->src[1] = b;
+ result->src[2] = ids;
+
+ return result;
+}
+
+// ggml_out_prod
+
+struct ggml_tensor * ggml_out_prod(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_can_out_prod(a, b));
+ GGML_ASSERT(!ggml_is_transposed(a));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+ const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_OUT_PROD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_scale
+
+static struct ggml_tensor * ggml_scale_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_padded_1d(a));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &s, sizeof(s));
+
+ result->op = GGML_OP_SCALE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_scale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s) {
+ return ggml_scale_impl(ctx, a, s, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float s) {
+ return ggml_scale_impl(ctx, a, s, true);
+}
+
+// ggml_set
+
+static struct ggml_tensor * ggml_set_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset,
+ bool inplace) {
+ GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ // make a view of the destination
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_SET;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_set(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_set_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
+struct ggml_tensor * ggml_set_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_1d_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
+}
+
+struct ggml_tensor * ggml_set_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_2d_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ size_t nb1,
+ size_t offset) {
+ return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
+}
+
+// ggml_cpy
+
+static struct ggml_tensor * ggml_cpy_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ // inplace is false and either one has a grad
+ is_node = true;
+ }
+
+ // make a view of the destination
+ struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+ if (strlen(b->name) > 0) {
+ ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
+ } else {
+ ggml_format_name(result, "%s (copy)", a->name);
+ }
+
+ result->op = GGML_OP_CPY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_cpy_impl(ctx, a, b);
+}
+
+struct ggml_tensor * ggml_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_type type) {
+ bool is_node = false;
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
+ ggml_format_name(result, "%s (copy)", a->name);
+
+ result->op = GGML_OP_CPY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = result;
+
+ return result;
+}
+
+// ggml_cont
+
+static struct ggml_tensor * ggml_cont_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+ ggml_format_name(result, "%s (cont)", a->name);
+
+ result->op = GGML_OP_CONT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cont(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cont_impl(ctx, a);
+}
+
+// make contiguous, with new shape
+GGML_API struct ggml_tensor * ggml_cont_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0) {
+ return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1) {
+ return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
+ return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+}
+
+struct ggml_tensor * ggml_cont_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3) {
+ GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+ bool is_node = false;
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+ ggml_format_name(result, "%s (cont)", a->name);
+
+ result->op = GGML_OP_CONT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ // as only the shape of b is relevant, and not its memory layout, b is allowed to be non-contiguous.
+ GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ if (b->grad) {
+ // gradient propagation is not supported
+ //GGML_ASSERT(false);
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
+ ggml_format_name(result, "%s (reshaped)", a->name);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_nelements(a) == ne0);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[1] = { ne0 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
+ ggml_format_name(result, "%s (reshaped)", a->name);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[2] = { ne0, ne1 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
+ ggml_format_name(result, "%s (reshaped)", a->name);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[3] = { ne0, ne1, ne2 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
+ ggml_format_name(result, "%s (reshaped)", a->name);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_reshape_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
+ ggml_format_name(result, "%s (reshaped)", a->name);
+
+ result->op = GGML_OP_RESHAPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+static struct ggml_tensor * ggml_view_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_dims,
+ const int64_t * ne,
+ size_t offset) {
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
+ ggml_format_name(result, "%s (view)", a->name);
+
+ ggml_set_op_params(result, &offset, sizeof(offset));
+
+ result->op = GGML_OP_VIEW;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ size_t offset) {
+
+ struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
+
+ return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ size_t nb1,
+ size_t offset) {
+
+ const int64_t ne[2] = { ne0, ne1 };
+
+ struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = result->nb[1]*ne1;
+ result->nb[3] = result->nb[2];
+
+ return result;
+}
+
+// ggml_view_3d
+
+struct ggml_tensor * ggml_view_3d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ size_t nb1,
+ size_t nb2,
+ size_t offset) {
+
+ const int64_t ne[3] = { ne0, ne1, ne2 };
+
+ struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = nb2;
+ result->nb[3] = result->nb[2]*ne2;
+
+ return result;
+}
+
+// ggml_view_4d
+
+struct ggml_tensor * ggml_view_4d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ size_t nb1,
+ size_t nb2,
+ size_t nb3,
+ size_t offset) {
+
+ const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+
+ struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
+
+ result->nb[1] = nb1;
+ result->nb[2] = nb2;
+ result->nb[3] = nb3;
+
+ return result;
+}
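+
+/*
+ usage sketch, with assumed names: view the rows [r0, r0 + nr) of a contiguous
+ matrix m with m->ne = { n_cols, n_rows }; strides and offset are in bytes
+
+   struct ggml_tensor * rows = ggml_view_2d(ctx, m,
+       m->ne[0], nr,     // ne0 (columns), ne1 (number of viewed rows)
+       m->nb[1],         // nb1: row stride, inherited from m
+       r0*m->nb[1]);     // byte offset of the first viewed row
+*/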
+
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int axis0,
+ int axis1,
+ int axis2,
+ int axis3) {
+ GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+ GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+ GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+ GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+ GGML_ASSERT(axis0 != axis1);
+ GGML_ASSERT(axis0 != axis2);
+ GGML_ASSERT(axis0 != axis3);
+ GGML_ASSERT(axis1 != axis2);
+ GGML_ASSERT(axis1 != axis3);
+ GGML_ASSERT(axis2 != axis3);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+ ggml_format_name(result, "%s (permuted)", a->name);
+
+ int ne[GGML_MAX_DIMS];
+ int nb[GGML_MAX_DIMS];
+
+ ne[axis0] = a->ne[0];
+ ne[axis1] = a->ne[1];
+ ne[axis2] = a->ne[2];
+ ne[axis3] = a->ne[3];
+
+ nb[axis0] = a->nb[0];
+ nb[axis1] = a->nb[1];
+ nb[axis2] = a->nb[2];
+ nb[axis3] = a->nb[3];
+
+ result->ne[0] = ne[0];
+ result->ne[1] = ne[1];
+ result->ne[2] = ne[2];
+ result->ne[3] = ne[3];
+
+ result->nb[0] = nb[0];
+ result->nb[1] = nb[1];
+ result->nb[2] = nb[2];
+ result->nb[3] = nb[3];
+
+ result->op = GGML_OP_PERMUTE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ int32_t params[] = { axis0, axis1, axis2, axis3 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+ ggml_format_name(result, "%s (transposed)", a->name);
+
+ result->ne[0] = a->ne[1];
+ result->ne[1] = a->ne[0];
+
+ result->nb[0] = a->nb[1];
+ result->nb[1] = a->nb[0];
+
+ result->op = GGML_OP_TRANSPOSE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
+ GGML_ASSERT(b->ne[3] == 1);
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ // TODO: implement non F32 return
+ enum ggml_type type = GGML_TYPE_F32;
+ if (a->type == GGML_TYPE_I32) {
+ type = a->type;
+ }
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+
+ result->op = GGML_OP_GET_ROWS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_get_rows_back
+
+struct ggml_tensor * ggml_get_rows_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c) {
+ GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
+ // TODO: implement non F32 return
+ //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
+
+ result->op = GGML_OP_GET_ROWS_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_diag
+
+struct ggml_tensor * ggml_diag(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ GGML_ASSERT(a->ne[1] == 1);
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
+
+ result->op = GGML_OP_DIAG;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_diag_mask_inf
+
+static struct ggml_tensor * ggml_diag_mask_inf_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ bool inplace) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ int32_t params[] = { n_past };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_DIAG_MASK_INF;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_diag_mask_inf(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_inf_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
+}
+
+// ggml_diag_mask_zero
+
+static struct ggml_tensor * ggml_diag_mask_zero_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ bool inplace) {
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ int32_t params[] = { n_past };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_DIAG_MASK_ZERO;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_diag_mask_zero(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_zero_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past) {
+ return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
+}
+
+// ggml_soft_max
+
+static struct ggml_tensor * ggml_soft_max_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_contiguous(a));
+
+ if (mask) {
+ GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(ggml_is_matrix(mask));
+ GGML_ASSERT(mask->ne[0] == a->ne[0]);
+ GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+ }
+
+ if (max_bias > 0.0f) {
+ GGML_ASSERT(mask);
+ }
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ float params[] = { scale, max_bias };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_SOFT_MAX;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = mask;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_soft_max(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
+}
+
+struct ggml_tensor * ggml_soft_max_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias) {
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
+}
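+
+/*
+ usage sketch, with assumed names: masked and scaled softmax over attention scores,
+ where kq is [n_kv, n_tokens, ...] and kq_mask is an F32/F16 matrix of at least
+ [n_kv, n_tokens]; a max_bias > 0.0f would additionally apply ALiBi slopes
+
+   kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head), 0.0f);
+*/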
+
+// ggml_soft_max_back
+
+static struct ggml_tensor * ggml_soft_max_back_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ bool inplace) {
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true; // TODO: implement backward pass
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SOFT_MAX_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
+// ggml_rope
+
+static struct ggml_tensor * ggml_rope_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow,
+ bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
+ GGML_ASSERT(ggml_is_vector(b));
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+ memcpy(params + 5, &freq_base, sizeof(float));
+ memcpy(params + 6, &freq_scale, sizeof(float));
+ memcpy(params + 7, &ext_factor, sizeof(float));
+ memcpy(params + 8, &attn_factor, sizeof(float));
+ memcpy(params + 9, &beta_fast, sizeof(float));
+ memcpy(params + 10, &beta_slow, sizeof(float));
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_ROPE;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_rope(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+ );
+}
+
+struct ggml_tensor * ggml_rope_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+ );
+}
+
+struct ggml_tensor * ggml_rope_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
+ );
+}
+
+struct ggml_tensor * ggml_rope_ext_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
+ );
+}
+
+struct ggml_tensor * ggml_rope_custom(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, false
+ );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, true
+ );
+}
+
+// ggml_rope_back
+
+struct ggml_tensor * ggml_rope_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ GGML_ASSERT(ggml_is_vector(b));
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] == b->ne[0]);
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
+
+ GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
+
+ bool is_node = false;
+
+ if (a->grad) {
+ is_node = false; // TODO: implement backward
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+ memcpy(params + 5, &freq_base, sizeof(float));
+ memcpy(params + 6, &freq_scale, sizeof(float));
+ memcpy(params + 7, &ext_factor, sizeof(float));
+ memcpy(params + 8, &attn_factor, sizeof(float));
+ memcpy(params + 9, &beta_fast, sizeof(float));
+ memcpy(params + 10, &beta_slow, sizeof(float));
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_ROPE_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_clamp
+
+struct ggml_tensor * ggml_clamp(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float min,
+ float max) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // TODO: when implementing backward, fix this:
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+ float params[] = { min, max };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_CLAMP;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_conv_1d
+
+static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+ return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+}
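+
+// worked example for the helper above: ins = 10, ks = 3, s = 1, p = 1, d = 1
+// gives (10 + 2 - 2 - 1)/1 + 1 = 10, i.e. a "same"-size output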
+
+GGML_API struct ggml_tensor * ggml_conv_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int p0,
+ int d0) {
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
+
+ struct ggml_tensor * result =
+ ggml_mul_mat(ctx,
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
+
+ result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
+
+ return result;
+}
+
+// ggml_conv_1d_ph
+
+struct ggml_tensor* ggml_conv_1d_ph(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s,
+ int d) {
+ return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+}
+
+// ggml_conv_transpose_1d
+
+static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+ return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+}
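+
+// worked example for the helper above: ins = 10, ks = 3, s = 2, p = 0, d = 1
+// gives (10 - 1)*2 - 0 + 1*(3 - 1) + 1 = 21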
+
+GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int p0,
+ int d0) {
+ GGML_ASSERT(ggml_is_matrix(b));
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
+ GGML_ASSERT(a->ne[3] == 1);
+
+ GGML_ASSERT(p0 == 0);
+ GGML_ASSERT(d0 == 1);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = {
+ ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+ a->ne[1], b->ne[2], 1,
+ };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ int32_t params[] = { s0, p0, d0 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_CONV_TRANSPOSE_1D;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_conv_depthwise_2d
+
+struct ggml_tensor * ggml_conv_depthwise_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1) {
+
+ struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
+ struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
+ ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
+ s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+ struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+
+ new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+ struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+
+ return result;
+}
+
+// ggml_conv_2d
+
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OH, OW, IC*KH*KW]
+struct ggml_tensor * ggml_im2col(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1,
+ bool is_2D,
+ enum ggml_type dst_type) {
+
+ if(is_2D) {
+ GGML_ASSERT(a->ne[2] == b->ne[2]);
+ } else {
+ GGML_ASSERT(a->ne[1] == b->ne[1]);
+ }
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+ const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
+ const int64_t ne[4] = {
+ is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
+ OW,
+ is_2D ? OH : b->ne[2],
+ is_2D ? b->ne[3] : 1,
+ };
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+ int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_IM2COL;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
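+
+/*
+ shape sketch for the 2D case, with assumed sizes N=1, IC=3, IH=IW=32, OC=16,
+ KH=KW=3, s0=s1=1, p0=p1=1, d0=d1=1 (written in the same [N, ...] order as above):
+
+   b (input)  : [N, IC, IH, IW]       = [1, 3, 32, 32]
+   a (kernel) : [OC, IC, KH, KW]      = [16, 3, 3, 3]
+   im2col     : [N, OH, OW, IC*KH*KW] = [1, 32, 32, 27]   (OH = OW = 32 with this padding)
+*/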
+
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OC, OH, OW]
+struct ggml_tensor * ggml_conv_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1) {
+ struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
+
+ struct ggml_tensor * result =
+ ggml_mul_mat(ctx,
+ ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OC,IC, KH, KW] => [OC, IC * KH * KW]
+
+ result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
+ result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
+
+ return result;
+}
+
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
+}
+
+// ggml_conv_2d_s1_ph
+
+struct ggml_tensor * ggml_conv_2d_s1_ph(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
+}
+
+// ggml_conv_transpose_2d_p0
+
+static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
+ return (ins - 1) * s - 2 * p + ks;
+}
+
+struct ggml_tensor * ggml_conv_transpose_2d_p0(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int stride) {
+ GGML_ASSERT(a->ne[3] == b->ne[2]);
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = {
+ ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
+ ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
+ a->ne[2], b->ne[3],
+ };
+
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ ggml_set_op_params_i32(result, 0, stride);
+
+ result->op = GGML_OP_CONV_TRANSPOSE_2D;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_pool_*
+
+static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
+ return (ins + 2 * p - ks) / s + 1;
+}
+
+// ggml_pool_1d
+
+struct ggml_tensor * ggml_pool_1d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_op_pool op,
+ int k0,
+ int s0,
+ int p0) {
+
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = {
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+ a->ne[1],
+ a->ne[2],
+ a->ne[3],
+ };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ int32_t params[] = { op, k0, s0, p0 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_POOL_1D;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_pool_2d
+
+struct ggml_tensor * ggml_pool_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_op_pool op,
+ int k0,
+ int k1,
+ int s0,
+ int s1,
+ float p0,
+ float p1) {
+
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result;
+ const int64_t ne[3] = {
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+ ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
+ a->ne[2],
+ };
+ result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+ int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_POOL_2D;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ return result;
+}
+
+// ggml_upscale
+
+static struct ggml_tensor * ggml_upscale_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ GGML_ASSERT(a->ne[0] <= ne0);
+ GGML_ASSERT(a->ne[1] <= ne1);
+ GGML_ASSERT(a->ne[2] <= ne2);
+ GGML_ASSERT(a->ne[3] <= ne3);
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+ ne0,
+ ne1,
+ ne2,
+ ne3
+ );
+
+ result->op = GGML_OP_UPSCALE;
+
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_upscale(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int scale_factor) {
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+}
+
+struct ggml_tensor * ggml_upscale_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3) {
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+}
+
+// ggml_pad
+
+struct ggml_tensor * ggml_pad(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int p0, int p1, int p2, int p3) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+ a->ne[0] + p0,
+ a->ne[1] + p1,
+ a->ne[2] + p2,
+ a->ne[3] + p3);
+
+ result->op = GGML_OP_PAD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_arange
+
+struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step) {
+
+ GGML_ASSERT(stop > start);
+
+ const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+ result->op = GGML_OP_ARANGE;
+ ggml_set_op_params_f32(result, 0, start);
+ ggml_set_op_params_f32(result, 1, stop);
+ ggml_set_op_params_f32(result, 2, step);
+
+ return result;
+}
+
+// ggml_timestep_embedding
+
+struct ggml_tensor * ggml_timestep_embedding(
+ struct ggml_context * ctx,
+ struct ggml_tensor * timesteps,
+ int dim,
+ int max_period) {
+ bool is_node = false;
+
+ if (timesteps->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ int actual_dim = dim;
+ if (dim % 2 != 0) {
+ actual_dim = dim + 1;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
+
+ result->op = GGML_OP_TIMESTEP_EMBEDDING;
+ ggml_set_op_params_i32(result, 0, dim);
+ ggml_set_op_params_i32(result, 1, max_period);
+
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = timesteps;
+
+ return result;
+}
+
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_sort_order order) {
+ bool is_node = false;
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+ result->op = GGML_OP_ARGSORT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k) {
+ GGML_ASSERT(a->ne[0] >= k);
+
+ struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+
+ result = ggml_view_4d(ctx, result,
+ k, result->ne[1], result->ne[2], result->ne[3],
+ result->nb[1], result->nb[2], result->nb[3],
+ 0);
+
+ return result;
+}
+
+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias) {
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
+ // TODO: check if vT can be multiplied by (k*qT)
+
+ if (mask) {
+ GGML_ASSERT(ggml_is_contiguous(mask));
+ GGML_ASSERT(mask->ne[2] == 1);
+ GGML_ASSERT(mask->ne[3] == 1);
+ GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+ "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+ //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+ }
+
+ if (max_bias > 0.0f) {
+ GGML_ASSERT(mask);
+ }
+
+ bool is_node = false;
+
+ if (q->grad || k->grad || v->grad) {
+ is_node = true;
+ }
+
+ // permute(0, 2, 1, 3)
+ int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ float params[] = { scale, max_bias };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_FLASH_ATTN_EXT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = mask;
+
+ return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+ struct ggml_tensor * a,
+ enum ggml_prec prec) {
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+ const int32_t prec_i32 = (int32_t) prec;
+
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+}
+
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * d,
+ bool masked) {
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
+ GGML_ASSERT(ggml_can_mul_mat(k, q));
+ // TODO: check if vT can be multiplied by (k*qT)
+
+ // d shape [D,N,ne2,ne3]
+ // q shape [D,N,ne2,ne3]
+ // k shape [D,M,kvne2,ne3]
+ // v shape [M,D,kvne2,ne3]
+
+ const int64_t D = q->ne[0];
+ const int64_t N = q->ne[1];
+ const int64_t M = k->ne[1];
+ const int64_t ne2 = q->ne[2];
+ const int64_t ne3 = q->ne[3];
+ const int64_t kvne2 = k->ne[2];
+
+ GGML_ASSERT(k->ne[0] == D);
+ GGML_ASSERT(v->ne[0] == M);
+ GGML_ASSERT(v->ne[1] == D);
+ GGML_ASSERT(d->ne[0] == D);
+ GGML_ASSERT(d->ne[1] == N);
+ GGML_ASSERT(k->ne[2] == kvne2);
+ GGML_ASSERT(k->ne[3] == ne3);
+ GGML_ASSERT(v->ne[2] == kvne2);
+ GGML_ASSERT(v->ne[3] == ne3);
+ GGML_ASSERT(d->ne[2] == ne2);
+ GGML_ASSERT(d->ne[3] == ne3);
+
+ GGML_ASSERT(ne2 % kvne2 == 0);
+
+ bool is_node = false;
+
+ if (q->grad || k->grad || v->grad) {
+ // when using this operation (in the backward pass) these grads are set.
+ // we don't want to create a (big) grad of our result, so is_node is false.
+ is_node = false;
+ }
+
+ // store gradients of q, k and v as contiguous tensors concatenated in result.
+ // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
+ const int64_t elem_v = ggml_nelements(v);
+
+ enum ggml_type result_type = GGML_TYPE_F32;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+ const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+ const size_t nelements = (end + tsize - 1)/tsize;
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
+
+ int32_t masked_i = masked ? 1 : 0;
+ ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
+ result->op = GGML_OP_FLASH_ATTN_BACK;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = d;
+
+ return result;
+}
+
+// ggml_ssm_conv
+
+struct ggml_tensor * ggml_ssm_conv(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * c,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_3d(s));
+ GGML_ASSERT(ggml_is_matrix(x));
+ GGML_ASSERT(ggml_is_matrix(c));
+ GGML_ASSERT(ggml_is_matrix(sq));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+
+ const int64_t d_conv = c->ne[0];
+ const int64_t d_inner = c->ne[1];
+ const int64_t n_tokens = x->ne[1];
+ const int64_t n_kv = s->ne[2];
+
+ GGML_ASSERT( s->ne[0] == d_conv - 1);
+ GGML_ASSERT( s->ne[1] == d_inner);
+ GGML_ASSERT( x->ne[0] == d_inner);
+ GGML_ASSERT(sq->ne[0] == n_kv);
+ GGML_ASSERT(sq->ne[1] == n_tokens);
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || c->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
+
+ result->op = GGML_OP_SSM_CONV;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = c;
+ result->src[3] = sq;
+
+ return result;
+}
+
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+ struct ggml_context * ctx,
+ struct ggml_tensor * s,
+ struct ggml_tensor * x,
+ struct ggml_tensor * dt,
+ struct ggml_tensor * A,
+ struct ggml_tensor * B,
+ struct ggml_tensor * C,
+ struct ggml_tensor * sq) {
+ GGML_ASSERT(ggml_is_contiguous(s));
+ GGML_ASSERT(ggml_is_contiguous(x));
+ GGML_ASSERT(ggml_is_contiguous(dt));
+ GGML_ASSERT(ggml_is_contiguous(A));
+ GGML_ASSERT(sq->type == GGML_TYPE_I32);
+ GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+ GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+ GGML_ASSERT(ggml_are_same_shape(x, dt));
+
+ {
+ const int64_t d_state = s->ne[0];
+ const int64_t d_inner = s->ne[1];
+ const int64_t n_tokens = x->ne[1];
+
+ GGML_ASSERT(x->ne[0] == d_inner);
+ GGML_ASSERT(A->ne[0] == d_state);
+ GGML_ASSERT(A->ne[1] == d_inner);
+ GGML_ASSERT(B->ne[0] == d_state);
+ GGML_ASSERT(B->ne[1] == n_tokens);
+ GGML_ASSERT(C->ne[0] == d_state);
+ GGML_ASSERT(C->ne[1] == n_tokens);
+ }
+
+ bool is_node = false;
+
+ if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
+ GGML_ASSERT(false); // TODO: implement
+ is_node = true;
+ }
+
+ // 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+ result->op = GGML_OP_SSM_SCAN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = s;
+ result->src[1] = x;
+ result->src[2] = dt;
+ result->src[3] = A;
+ result->src[4] = B;
+ result->src[5] = C;
+ result->src[6] = sq;
+
+ return result;
+}
+
+// ggml_win_part
+
+struct ggml_tensor * ggml_win_part(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int w) {
+ GGML_ASSERT(a->ne[3] == 1);
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ // padding
+ const int px = (w - a->ne[1]%w)%w;
+ const int py = (w - a->ne[2]%w)%w;
+
+ const int npx = (px + a->ne[1])/w;
+ const int npy = (py + a->ne[2])/w;
+ const int np = npx*npy;
+
+ const int64_t ne[4] = { a->ne[0], w, w, np, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ int32_t params[] = { npx, npy, w };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_WIN_PART;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int w0,
+ int h0,
+ int w) {
+ GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+ int32_t params[] = { w };
+ ggml_set_op_params(result, params, sizeof(params));
+
+ result->op = GGML_OP_WIN_UNPART;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_get_rel_pos
+
+struct ggml_tensor * ggml_get_rel_pos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int qh,
+ int kh) {
+ GGML_ASSERT(qh == kh);
+ GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
+
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
+
+ result->op = GGML_OP_GET_REL_POS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+// ggml_add_rel_pos
+
+static struct ggml_tensor * ggml_add_rel_pos_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * pw,
+ struct ggml_tensor * ph,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(pw, ph));
+ GGML_ASSERT(ggml_is_contiguous(a));
+ GGML_ASSERT(ggml_is_contiguous(pw));
+ GGML_ASSERT(ggml_is_contiguous(ph));
+ GGML_ASSERT(ph->type == GGML_TYPE_F32);
+ GGML_ASSERT(pw->type == GGML_TYPE_F32);
+ GGML_ASSERT(pw->ne[3] == a->ne[2]);
+ GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
+ GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || pw->grad || ph->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
+
+ result->op = GGML_OP_ADD_REL_POS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = pw;
+ result->src[2] = ph;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_add_rel_pos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * pw,
+ struct ggml_tensor * ph) {
+ return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
+}
+
+struct ggml_tensor * ggml_add_rel_pos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * pw,
+ struct ggml_tensor * ph) {
+ return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
+}
+
+// ggml_unary
+
+static struct ggml_tensor * ggml_unary_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_unary_op op,
+ bool inplace) {
+ GGML_ASSERT(ggml_is_contiguous_1(a));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+ result->op = GGML_OP_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_unary(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_unary_op op) {
+ return ggml_unary_impl(ctx, a, op, false);
+}
+
+struct ggml_tensor * ggml_unary_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_unary_op op) {
+ return ggml_unary_impl(ctx, a, op, true);
+}
+
+// ggml_map_unary
+
+static struct ggml_tensor * ggml_map_unary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+ result->op = GGML_OP_MAP_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+static struct ggml_tensor * ggml_map_binary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+ result->op = GGML_OP_MAP_BINARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_binary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom1_f32
+
+static struct ggml_tensor * ggml_map_custom1_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+ result->op = GGML_OP_MAP_CUSTOM1_F32;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom1_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_f32_t fun) {
+ return ggml_map_custom1_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom1_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_f32_t fun) {
+ return ggml_map_custom1_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_custom2_f32
+
+static struct ggml_tensor * ggml_map_custom2_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+ result->op = GGML_OP_MAP_CUSTOM2_F32;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom2_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_f32_t fun) {
+ return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_f32_t fun) {
+ return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom3_f32
+
+static struct ggml_tensor * ggml_map_custom3_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad || c->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+ result->op = GGML_OP_MAP_CUSTOM3_F32;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom3_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_f32_t fun) {
+ return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_f32_t fun) {
+ return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
+}
+
+// ggml_map_custom1
+
+struct ggml_map_custom1_op_params {
+ ggml_custom1_op_t fun;
+ int n_tasks;
+ void * userdata;
+};
+
+static struct ggml_tensor * ggml_map_custom1_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_t fun,
+ int n_tasks,
+ void * userdata,
+ bool inplace) {
+ GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ struct ggml_map_custom1_op_params params = {
+ /*.fun =*/ fun,
+ /*.n_tasks =*/ n_tasks,
+ /*.userdata =*/ userdata
+ };
+ ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+ result->op = GGML_OP_MAP_CUSTOM1;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom1(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom1_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_custom1_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
+}
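+
+/*
+ usage sketch, assuming the ggml_custom1_op_t callback signature declared in ggml.h
+ (dst, src, thread index, thread count, userdata):
+
+   static void my_custom_op(struct ggml_tensor * dst, const struct ggml_tensor * a,
+                            int ith, int nth, void * userdata) {
+       // partition the rows of dst across nth threads using ith, then fill dst from a
+   }
+
+   struct ggml_tensor * out = ggml_map_custom1(ctx, x, my_custom_op, GGML_N_TASKS_MAX, NULL);
+*/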
+
+// ggml_map_custom2
+
+struct ggml_map_custom2_op_params {
+ ggml_custom2_op_t fun;
+ int n_tasks;
+ void * userdata;
+};
+
+static struct ggml_tensor * ggml_map_custom2_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_t fun,
+ int n_tasks,
+ void * userdata,
+ bool inplace) {
+ GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ struct ggml_map_custom2_op_params params = {
+ /*.fun =*/ fun,
+ /*.n_tasks =*/ n_tasks,
+ /*.userdata =*/ userdata
+ };
+ ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+ result->op = GGML_OP_MAP_CUSTOM2;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom2(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_custom2_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
+}
+
+// ggml_map_custom3
+
+struct ggml_map_custom3_op_params {
+ ggml_custom3_op_t fun;
+ int n_tasks;
+ void * userdata;
+};
+
+static struct ggml_tensor * ggml_map_custom3_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_t fun,
+ int n_tasks,
+ void * userdata,
+ bool inplace) {
+ GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad || c->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ struct ggml_map_custom3_op_params params = {
+ /*.fun =*/ fun,
+ /*.n_tasks =*/ n_tasks,
+ /*.userdata =*/ userdata
+ };
+ ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+ result->op = GGML_OP_MAP_CUSTOM3;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_custom3(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ const ggml_custom3_op_t fun,
+ int n_tasks,
+ void * userdata) {
+ return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
+}
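+
+// illustrative sketch of a hypothetical user callback (my_custom_op) for ggml_map_custom1;
+// the callback receives its thread index/count and is expected to split the work itself:
+//
+// static void my_custom_op(struct ggml_tensor * dst, const struct ggml_tensor * a,
+// int ith, int nth, void * userdata) {
+// const int64_t nr = ggml_nrows(dst);
+// const int64_t dr = (nr + nth - 1)/nth; // rows per thread
+// const int64_t r0 = dr*ith;
+// const int64_t r1 = MIN(r0 + dr, nr);
+// for (int64_t r = r0; r < r1; ++r) {
+// // process row r of a into row r of dst
+// }
+// }
+//
+// struct ggml_tensor * t = ggml_map_custom1(ctx, a, my_custom_op, GGML_N_TASKS_MAX, NULL);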
+
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ is_node = true;
+ }
+
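+ // the loss is reduced over all elements into a single scalar result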
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+ GGML_ASSERT(ggml_is_scalar(c));
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+ result->grad = NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(
+ struct ggml_context * ctx,
+ struct ggml_tensor * tensor) {
+ tensor->flags |= GGML_TENSOR_FLAG_PARAM;
+
+ GGML_ASSERT(tensor->grad == NULL);
+ tensor->grad = ggml_dup_tensor(ctx, tensor);
+ ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
+}
+
+// ggml_compute_forward_dup
+
+static void ggml_compute_forward_dup_same_cont(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+ GGML_ASSERT(src0->type == dst->type);
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb0 = dst->nb[0];
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by elements
+ const int ne = ggml_nelements(dst);
+ const int dr = (ne + nth - 1) / nth;
+ const int ie0 = dr * ith;
+ const int ie1 = MIN(ie0 + dr, ne);
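+ // e.g. ne = 10, nth = 4 -> dr = 3 and the per-thread element ranges are
+ // [0,3), [3,6), [6,9), [9,10)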
+
+ if (ie0 < ie1) {
+ memcpy(
+ ((char *) dst->data + ie0*nb0),
+ ((char *) src0->data + ie0*nb00),
+ (ie1 - ie0) * ggml_type_size(src0->type));
+ }
+}
+
+static void ggml_compute_forward_dup_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+ ggml_compute_forward_dup_same_cont(params, dst);
+ return;
+ }
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+ // copy by rows
+ const size_t rs = ne00*nb00;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+ if (ggml_is_contiguous(dst)) {
+ if (nb00 == sizeof(ggml_fp16_t)) {
+ if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ const size_t rs = ne00 * nb00;
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (type_traits[dst->type].from_float) {
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+ }
+
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+ return;
+ }
+
+ // dst counters
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
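+ // the dst counters (i10, i11, i12, i13) walk the destination as one flat stream of
+ // elements, carrying into the next dimension whenever a counter reaches its extent;
+ // this handles src0 and dst having different shapes with the same number of elements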
+ if (dst->type == GGML_TYPE_F16) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+ if (++i10 == ne00) {
+ i10 = 0;
+ if (++i11 == ne01) {
+ i11 = 0;
+ if (++i12 == ne02) {
+ i12 = 0;
+ if (++i13 == ne03) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+}
+
+static void ggml_compute_forward_dup_bf16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+ ggml_compute_forward_dup_same_cont(params, dst);
+ return;
+ }
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+ // copy by rows
+ const size_t rs = ne00*nb00;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+ if (ggml_is_contiguous(dst)) {
+ if (nb00 == sizeof(ggml_bf16_t)) {
+ if (dst->type == GGML_TYPE_BF16) {
+ size_t id = 0;
+ const size_t rs = ne00 * nb00;
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ for (int i00 = 0; i00 < ne00; i00++) {
+ dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (type_traits[dst->type].from_float) {
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int i00 = 0; i00 < ne00; i00++) {
+ src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+ }
+
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_BF16) {
+ size_t id = 0;
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+ return;
+ }
+
+ // dst counters
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ if (dst->type == GGML_TYPE_BF16) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+
+ if (++i10 == ne00) {
+ i10 = 0;
+ if (++i11 == ne01) {
+ i11 = 0;
+ if (++i12 == ne02) {
+ i12 = 0;
+ if (++i13 == ne03) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+}
+
+static void ggml_compute_forward_dup_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
+ ggml_compute_forward_dup_same_cont(params, dst);
+ return;
+ }
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+ // copy by rows
+ const size_t rs = ne00*nb00;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ if (ggml_is_contiguous(dst)) {
+ // TODO: simplify
+ if (nb00 == sizeof(float)) {
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ const size_t rs = ne00 * nb00;
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else if (type_traits[dst->type].from_float) {
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
+
+ size_t id = 0;
+ size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+ char * dst_ptr = (char *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else if (dst->type == GGML_TYPE_BF16) {
+ size_t id = 0;
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ id += ne00 * ir0;
+ for (int i01 = ir0; i01 < ir1; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
+ id++;
+ }
+ }
+ id += ne00 * (ne01 - ir1);
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+
+ return;
+ }
+
+ // dst counters
+
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ if (dst->type == GGML_TYPE_F32) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_BF16) {
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+}
+
+// A simplified version of ggml_compute_forward_dup that skips float upcasting and just does a plain memcpy.
+static void ggml_compute_forward_dup_bytes(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(src0->type == dst->type);
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+ ggml_compute_forward_dup_same_cont(params, dst);
+ return;
+ }
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ const size_t type_size = ggml_type_size(src0->type);
+ const int ith = params->ith; // thread index
+ const int nth = params->nth; // number of threads
+
+ // parallelize by rows
+ const int nr = ne01;
+ // number of rows per thread
+ const int dr = (nr + nth - 1) / nth;
+ // row range for this thread
+ const int ir0 = dr * ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (src0->type == dst->type &&
+ ne00 == ne0 &&
+ nb00 == type_size && nb0 == type_size) {
+ // copy by rows
+ const size_t rs = ne00 * type_size;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ if (ggml_is_contiguous(dst)) {
+ size_t id = 0;
+ char * dst_ptr = (char *) dst->data;
+ const size_t rs = ne00 * type_size;
+
+ if (nb00 == type_size) {
+ // src0 is contiguous in the first dimension, copy by rows
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, rs);
+ id += rs;
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ id += rs * ir0;
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+ memcpy(dst_ptr + id, src0_ptr, type_size);
+
+ id += type_size;
+ }
+ }
+ id += rs * (ne01 - ir1);
+ }
+ }
+ }
+
+ return;
+ }
+
+ // dst counters
+
+ int64_t i10 = 0;
+ int64_t i11 = 0;
+ int64_t i12 = 0;
+ int64_t i13 = 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ i10 += ne00 * ir0;
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
+
+ memcpy(dst_ptr, src0_ptr, type_size);
+
+ if (++i10 == ne0) {
+ i10 = 0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+ i10 += ne00 * (ne01 - ir1);
+ while (i10 >= ne0) {
+ i10 -= ne0;
+ if (++i11 == ne1) {
+ i11 = 0;
+ if (++i12 == ne2) {
+ i12 = 0;
+ if (++i13 == ne3) {
+ i13 = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_dup(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (src0->type == dst->type) {
+ ggml_compute_forward_dup_bytes(params, dst);
+ return;
+ }
+
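+ // src0 and dst have different types, so the copy goes through one of the converting paths below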
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_dup_f16(params, dst);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ggml_compute_forward_dup_bf16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_dup_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (nb10 == sizeof(float)) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+ const int64_t nr0 = ne00 / ne10;
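+ // src1 can also be broadcast along the first dimension when ne00 is a multiple
+ // of ne10: the same src1 row is added nr0 times (e.g. ne00 = 8, ne10 = 4 -> nr0 = 2)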
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+ for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+ vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+ ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+ }
+ }
+ } else {
+ // src1 is not contiguous
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ const int64_t i10 = i0 % ne10;
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+ dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ if (dst->type == GGML_TYPE_F32) {
+ GGML_ASSERT( nb0 == sizeof(float));
+ }
+ else {
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+ }
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (nb10 == sizeof(float)) {
+ if (dst->type == GGML_TYPE_F16) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+ }
+ }
+ } else {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+ }
+ }
+ }
+ }
+ else {
+ // src1 is not contiguous
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_compute_forward_add_bf16_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ if (dst->type == GGML_TYPE_F32) {
+ GGML_ASSERT( nb0 == sizeof(float));
+ }
+ else {
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+ }
+
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (nb10 == sizeof(float)) {
+ if (dst->type == GGML_TYPE_BF16) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+ }
+ }
+ } else {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+ }
+ }
+ }
+ }
+ else {
+ // src1 is not contiguous
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (nb10 == sizeof(ggml_fp16_t)) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+ }
+ }
+ }
+ else {
+ // src1 is not contiguous
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_compute_forward_add_bf16_bf16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ if (nb10 == sizeof(ggml_bf16_t)) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
+ }
+ }
+ }
+ else {
+ // src1 is not contiguous
+ GGML_ASSERT(false);
+ }
+}
+
+static void ggml_compute_forward_add_q_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+ const enum ggml_type dtype = dst->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+ ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float;
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ggml_is_quantized(src0->type));
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
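+ // per-thread scratch row of ne00 floats in params->wdata, padded by
+ // CACHE_LINE_SIZE_F32 to keep threads on separate cache lines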
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ // src1 and dst are same shape as src0 => same indices
+ const int i13 = i03;
+ const int i12 = i02;
+ const int i11 = i01;
+
+ const int i3 = i03;
+ const int i2 = i02;
+ const int i1 = i01;
+
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ assert(ne00 % 32 == 0);
+
+ // dequantize row from src0 to temp buffer
+ dequantize_row_q(src0_row, wdata, ne00);
+ // add src1
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
+ // quantize row to dst
+ if (quantize_row_q != NULL) {
+ quantize_row_q(wdata, dst_row, ne00);
+ } else {
+ memcpy(dst_row, wdata, ne0*nb0);
+ }
+ }
+}
+
+static void ggml_compute_forward_add(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ if (src1->type == GGML_TYPE_F32) {
+ ggml_compute_forward_add_f32(params, dst);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_TYPE_F16:
+ {
+ if (src1->type == GGML_TYPE_F16) {
+ ggml_compute_forward_add_f16_f16(params, dst);
+ }
+ else if (src1->type == GGML_TYPE_F32) {
+ ggml_compute_forward_add_f16_f32(params, dst);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ if (src1->type == GGML_TYPE_BF16) {
+ ggml_compute_forward_add_bf16_bf16(params, dst);
+ }
+ else if (src1->type == GGML_TYPE_F32) {
+ ggml_compute_forward_add_bf16_f32(params, dst);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ {
+ ggml_compute_forward_add_q_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_add1
+
+static void ggml_compute_forward_add1_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+ UNUSED(ggml_vec_add1_f32);
+
+ vDSP_vadd(
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+ (float *) ((char *) src1->data), 0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
+ ne0);
+#else
+ ggml_vec_add1_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ *(float *) src1->data);
+#endif
+ }
+}
+
+static void ggml_compute_forward_add1_f16_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ // scalar to add
+ const float v = *(float *) src1->data;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+ }
+ }
+}
+
+static void ggml_compute_forward_add1_f16_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ // scalar to add
+ const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+ }
+ }
+}
+
+static void ggml_compute_forward_add1_q_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ // scalar to add
+ const float v = *(float *) src1->data;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const enum ggml_type type = src0->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
+
+ // we don't support permuted src0
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ggml_is_quantized(src0->type));
+ GGML_ASSERT(dst->type == src0->type);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+ void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3 ));
+
+ assert(ne0 % 32 == 0);
+
+ // dequantize row from src0 to temp buffer
+ dequantize_row_q(src0_row, wdata, ne0);
+ // add src1
+ ggml_vec_acc1_f32(ne0, wdata, v);
+ // quantize row to dst
+ quantize_row_q(wdata, dst_row, ne0);
+ }
+}
+
+static void ggml_compute_forward_add1_bf16_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ // scalar to add
+ const float v = *(float *) src1->data;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+ }
+ }
+}
+
+static void ggml_compute_forward_add1_bf16_bf16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_scalar(src1));
+
+ // scalar to add
+ const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+ GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+ GGML_ASSERT(dst->type == GGML_TYPE_BF16);
+
+ GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+ ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ for (int i = 0; i < ne0; i++) {
+ dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+ }
+ }
+}
+
+static void ggml_compute_forward_add1(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add1_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ if (src1->type == GGML_TYPE_F16) {
+ ggml_compute_forward_add1_f16_f16(params, dst);
+ }
+ else if (src1->type == GGML_TYPE_F32) {
+ ggml_compute_forward_add1_f16_f32(params, dst);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ if (src1->type == GGML_TYPE_BF16) {
+ ggml_compute_forward_add1_bf16_bf16(params, dst);
+ }
+ else if (src1->type == GGML_TYPE_F32) {
+ ggml_compute_forward_add1_bf16_f32(params, dst);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ {
+ ggml_compute_forward_add1_q_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_acc
+
+static void ggml_compute_forward_acc_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+ // view src0 and dst with these strides and data offset in bytes during acc
+ // nb0 is implicitly element_size because src0 and dst are contiguous
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
+ size_t offset = ((int32_t *) dst->op_params)[3];
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ if (params->ith == 0) {
+ // memcpy needs to be synchronized across threads to avoid race conditions:
+ // only thread 0 copies, then all threads wait on the barrier below
+ memcpy(
+ ((char *) dst->data),
+ ((char *) src0->data),
+ ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src1);
+ const int nc = src1->ne[0];
+
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
+ // src0 and dst as viewed during acc
+ const size_t nb0 = ggml_element_size(src0);
+
+ const size_t nb00 = nb0;
+ const size_t nb01 = nb1;
+ const size_t nb02 = nb2;
+ const size_t nb03 = nb3;
+
+ GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
+ GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are viewed with shape of src1 and offset
+ // => same indices
+ const int i3 = ir/(ne12*ne11);
+ const int i2 = (ir - i3*ne12*ne11)/ne11;
+ const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+#ifdef GGML_USE_ACCELERATE
+ vDSP_vadd(
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
+ (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
+#else
+ ggml_vec_add_f32(nc,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
+ (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+ }
+}
+
+static void ggml_compute_forward_acc(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_acc_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sub
+
+static void ggml_compute_forward_sub_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (nb10 == sizeof(float)) {
+ for (int ir = 0; ir < nr; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+ vDSP_vsub(
+ (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
+ ne0);
+#else
+ ggml_vec_sub_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+ }
+ } else {
+ // src1 is not contiguous
+ for (int ir = 0; ir < nr; ++ir) {
+ // src0, src1 and dst are same shape => same indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+ for (int i0 = 0; i0 < ne0; i0++) {
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+
+ dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_sub(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sub_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul
+
+static void ggml_compute_forward_mul_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
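+    // fast path: a single-element src1 is a plain scalar multiply, so copy src0 (if needed)
+    // and scale it in place instead of going through the broadcast loops below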
+    if (ggml_nelements(src1) == 1 && ggml_is_contiguous(src0) && ggml_is_contiguous(dst) &&
+        src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        int64_t nelements = ggml_nelements(src0);
+        int64_t n_per_thread = (nelements + nth - 1)/nth;
+        n_per_thread = MAX(1024, n_per_thread);
+        int64_t start = n_per_thread*ith;
+        if (start >= nelements) return;
+        int64_t end = MIN(nelements, start + n_per_thread);
+        const float * src = (const float *)src0->data + start;
+        float * res = (float *)dst->data + start;
+        if (res != src) {
+            memcpy(res, src, (end - start)*sizeof(float));
+        }
+        ggml_vec_scale_f32(end - start, res, *(const float *)src1->data);
+ return;
+ }
+
+ const int64_t nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (nb10 == sizeof(float)) {
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
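+            // nr0 = number of times src1's row is tiled along dim 0 (exact, guaranteed by ggml_can_repeat)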
+ const int64_t nr0 = ne00 / ne10;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+ for (int64_t r = 0 ; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+ UNUSED(ggml_vec_mul_f32);
+
+ vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+ ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+ }
+ }
+ } else {
+ // src1 is not contiguous
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
+ const int64_t i10 = i0 % ne10;
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+ dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_mul(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mul_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_div
+
+static void ggml_compute_forward_div_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (nb10 == sizeof(float)) {
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+ const int64_t nr0 = ne00 / ne10;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+ for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+ UNUSED(ggml_vec_div_f32);
+
+ vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+ ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+ }
+ }
+ } else {
+ // src1 is not contiguous
+ for (int64_t ir = ith; ir < nr; ir += nth) {
+ // src0 and dst are same shape => same indices
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ for (int64_t i0 = 0; i0 < ne00; ++i0) {
+ const int64_t i10 = i0 % ne10;
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+ dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_div(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_div_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqr
+
+static void ggml_compute_forward_sqr_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqr_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sqr(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqr_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sqrt
+
+static void ggml_compute_forward_sqrt_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sqrt_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sqrt(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sqrt_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_log
+
+static void ggml_compute_forward_log_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_log_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_log(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_log_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_scalar(dst));
+
+ assert(src0->nb[0] == sizeof(float));
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+
+ ggml_float sum = 0;
+ ggml_float row_sum = 0;
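+    // accumulate row sums in the wider ggml_float type to limit rounding error; the final result is stored as f32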
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_f32_ggf(ne00,
+ &row_sum,
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+ sum += row_sum;
+ }
+ }
+ }
+ ((float *) dst->data)[0] = sum;
+}
+
+static void ggml_compute_forward_sum_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_scalar(dst));
+
+ assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+
+ float sum = 0;
+ float row_sum = 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_f16_ggf(ne00,
+ &row_sum,
+ (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+ sum += row_sum;
+ }
+ }
+ }
+ ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+}
+
+static void ggml_compute_forward_sum_bf16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_scalar(dst));
+
+ assert(src0->nb[0] == sizeof(ggml_bf16_t));
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+
+ float sum = 0;
+ float row_sum = 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_bf16_ggf(ne00,
+ &row_sum,
+ (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+ sum += row_sum;
+ }
+ }
+ }
+ ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
+}
+
+static void ggml_compute_forward_sum(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sum_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_sum_f16(params, dst);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ggml_compute_forward_sum_bf16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(ne0 == 1);
+ GGML_ASSERT(ne1 == ne01);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ for (int64_t i3 = 0; i3 < ne03; i3++) {
+ for (int64_t i2 = 0; i2 < ne02; i2++) {
+ for (int64_t i1 = 0; i1 < ne01; i1++) {
+ float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+ float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
+ float row_sum = 0;
+ ggml_vec_sum_f32(ne00, &row_sum, src_row);
+ dst_row[0] = row_sum;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_sum_rows(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sum_rows_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(src0->nb[0] == sizeof(float));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ assert(ne0 == 1);
+ assert(ne1 == ne01);
+ assert(ne2 == ne02);
+ assert(ne3 == ne03);
+
+ UNUSED(ne0);
+ UNUSED(ne1);
+ UNUSED(ne2);
+ UNUSED(ne3);
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ ggml_vec_sum_f32(ne00,
+ (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+ *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_mean(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_mean_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_argmax
+
+static void ggml_compute_forward_argmax_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(src0->nb[0] == sizeof(float));
+ assert(dst->nb[0] == sizeof(float));
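+    // dst holds int32 row indices; sizeof(int32_t) == sizeof(float), so the stride check above still holds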
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+
+ const size_t nb01 = src0->nb[1];
+ const size_t nb0 = dst->nb[0];
+
+ for (int64_t i1 = 0; i1 < ne01; i1++) {
+ float * src = (float *) ((char *) src0->data + i1*nb01);
+ int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
+ int v = 0;
+ ggml_vec_argmax_f32(ne00, &v, src);
+ dst_[0] = v;
+ }
+}
+
+static void ggml_compute_forward_argmax(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_argmax_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nr0 = (int)(ne0/ne00);
+ const int nr1 = (int)(ne1/ne01);
+ const int nr2 = (int)(ne2/ne02);
+ const int nr3 = (int)(ne3/ne03);
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
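+    // dst is filled with nr0*nr1*nr2*nr3 copies of src0: the k-loops below walk the source
+    // tensor while the i-loops walk the repetitions along each dimension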
+
+ // TODO: maybe this is not optimal?
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne03; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne02; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne01; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ ggml_vec_cpy_f32(ne00,
+ (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
+ (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_repeat_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nr0 = (int)(ne0/ne00);
+ const int nr1 = (int)(ne1/ne01);
+ const int nr2 = (int)(ne2/ne02);
+ const int nr3 = (int)(ne3/ne03);
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+ // TODO: maybe this is not optimal?
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne03; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne02; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne01; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
+ ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
+ // ggml_vec_cpy_f16(ne00, y, x)
+ for (int i = 0; i < ne00; ++i) {
+ y[i] = x[i];
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_repeat(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_I16:
+ {
+ ggml_compute_forward_repeat_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_repeat_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ // guaranteed to be an integer due to the check in ggml_can_repeat
+ const int nr0 = (int)(ne00/ne0);
+ const int nr1 = (int)(ne01/ne1);
+ const int nr2 = (int)(ne02/ne2);
+ const int nr3 = (int)(ne03/ne3);
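+    // adjoint of repeat: each dst element accumulates (via ggml_vec_acc_f32) the contributions
+    // of all nr0*nr1*nr2*nr3 repeated copies found in src0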
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ if (ggml_is_contiguous(dst)) {
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+ } else {
+ for (int k3 = 0; k3 < ne3; k3++) {
+ for (int k2 = 0; k2 < ne2; k2++) {
+ for (int k1 = 0; k1 < ne1; k1++) {
+ ggml_vec_set_f32(ne0,
+ (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+ 0);
+ }
+ }
+ }
+ }
+
+ // TODO: maybe this is not optimal?
+ for (int i3 = 0; i3 < nr3; i3++) {
+ for (int k3 = 0; k3 < ne3; k3++) {
+ for (int i2 = 0; i2 < nr2; i2++) {
+ for (int k2 = 0; k2 < ne2; k2++) {
+ for (int i1 = 0; i1 < nr1; i1++) {
+ for (int k1 = 0; k1 < ne1; k1++) {
+ for (int i0 = 0; i0 < nr0; i0++) {
+ ggml_vec_acc_f32(ne0,
+ (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
+ (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_repeat_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_repeat_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // TODO: support for transposed / permuted tensors
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+ GGML_ASSERT(dim >= 0 && dim < 4);
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = src0->ne[dim];
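+    // indices past src0's extent along the concat dimension are mapped into src1 by subtracting o[dim]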
+
+ const float * x;
+
+    // TODO: smarter multi-threading
+ for (int i3 = 0; i3 < ne3; i3++) {
+ for (int i2 = ith; i2 < ne2; i2 += nth) {
+ for (int i1 = 0; i1 < ne1; i1++) {
+ for (int i0 = 0; i0 < ne0; i0++) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
+ } else {
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+ }
+
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+ *y = *x;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_concat(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_concat_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_abs_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_abs(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_abs_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sgn
+
+static void ggml_compute_forward_sgn_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sgn_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sgn(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sgn_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_neg
+
+static void ggml_compute_forward_neg_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_neg_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_neg(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_neg_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_step
+
+static void ggml_compute_forward_step_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_step_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_step(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_step_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_tanh
+
+static void ggml_compute_forward_tanh_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_tanh_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_tanh(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_tanh_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_elu
+
+static void ggml_compute_forward_elu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_elu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_elu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_elu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_relu
+
+static void ggml_compute_forward_relu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_relu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_relu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_relu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sigmoid_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sigmoid(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sigmoid_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_gelu
+
+static void ggml_compute_forward_gelu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_gelu_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_gelu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gelu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_gelu_quick_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_gelu_quick(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_gelu_quick_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_silu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_silu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_leaky_relu
+
+static void ggml_compute_forward_leaky_relu_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_leaky_relu_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+ }
+}
+
+static void ggml_compute_forward_leaky_relu(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_leaky_relu_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * grad = dst->src[1];
+
+ assert(ggml_is_contiguous_1(grad));
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+ assert(ggml_are_same_shape(src0, grad));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ ggml_vec_silu_backward_f32(nc,
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
+ (float *) ((char *) src0->data + i1*(src0->nb[1])),
+ (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+ UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_silu_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_silu_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_hardswish
+
+static void ggml_compute_forward_hardswish_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_hardswish_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_hardswish(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_hardswish_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_hardsigmoid
+
+static void ggml_compute_forward_hardsigmoid_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_hardsigmoid_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_hardsigmoid(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_hardsigmoid_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(eps > 0.0f);
+
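+    // standard layer norm along the first dimension: subtract the row mean,
+    // then scale by 1/sqrt(variance + eps)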
+ // TODO: optimize
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float sum = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sum += (ggml_float)x[i00];
+ }
+
+ float mean = sum/ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ ggml_float sum2 = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ float v = x[i00] - mean;
+ y[i00] = v;
+ sum2 += (ggml_float)(v*v);
+ }
+
+ float variance = sum2/ne00;
+ const float scale = 1.0f/sqrtf(variance + eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_norm(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_norm_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_rms_norm
+
+static void ggml_compute_forward_rms_norm_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(eps > 0.0f);
+
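+    // RMS norm: scale each row by 1/sqrt(mean(x^2) + eps); unlike ggml_compute_forward_norm,
+    // the mean is not subtracted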
+ // TODO: optimize
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float sum = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sum += (ggml_float)(x[i00] * x[i00]);
+ }
+
+ const float mean = sum/ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ memcpy(y, x, ne00 * sizeof(float));
+ // for (int i00 = 0; i00 < ne00; i00++) {
+ // y[i00] = x[i00];
+ // }
+
+ const float scale = 1.0f/sqrtf(mean + eps);
+
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rms_norm(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rms_norm_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+static void ggml_compute_forward_rms_norm_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ // TODO: optimize
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ // src1 is same shape as src0 => same indices
+ const int64_t i11 = i01;
+ const int64_t i12 = i02;
+ const int64_t i13 = i03;
+
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+ const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+ ggml_float sum_xx = 0.0;
+ ggml_float sum_xdz = 0.0;
+
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sum_xx += (ggml_float)(x[i00] * x[i00]);
+ sum_xdz += (ggml_float)(x[i00] * dz[i00]);
+ }
+
+ //const float mean = (float)(sum_xx)/ne00;
+ const float mean_eps = (float)(sum_xx)/ne00 + eps;
+ const float sum_eps = (float)(sum_xx) + eps*ne00;
+ //const float mean_xdz = (float)(sum_xdz)/ne00;
+ // we could cache rms from forward pass to improve performance.
+ // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
+ //const float rms = sqrtf(mean_eps);
+ const float rrms = 1.0f / sqrtf(mean_eps);
+ //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
+
+ {
+ // z = rms_norm(x)
+ //
+ // rms_norm(src0) =
+ // scale(
+ // src0,
+ // div(
+ // 1,
+ // sqrt(
+ // add(
+ // scale(
+ // sum(
+ // sqr(
+ // src0)),
+ // (1.0/N)),
+ // eps))));
+
+ // postorder:
+ // ## op args grad
+ // 00 param src0 grad[#00]
+ // 01 const 1
+ // 02 sqr (#00) grad[#02]
+ // 03 sum (#02) grad[#03]
+ // 04 const 1/N
+ // 05 scale (#03, #04) grad[#05]
+ // 06 const eps
+ // 07 add (#05, #06) grad[#07]
+ // 08 sqrt (#07) grad[#08]
+ // 09 div (#01,#08) grad[#09]
+ // 10 scale (#00,#09) grad[#10]
+ //
+ // backward pass, given grad[#10]
+ // #10: scale
+ // grad[#00] += scale(grad[#10],#09)
+ // grad[#09] += sum(mul(grad[#10],#00))
+ // #09: div
+ // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
+ // #08: sqrt
+ // grad[#07] += mul(grad[#08], div(0.5, #08))
+ // #07: add
+ // grad[#05] += grad[#07]
+ // #05: scale
+ // grad[#03] += scale(grad[#05],#04)
+ // #03: sum
+ // grad[#02] += repeat(grad[#03], #02)
+ // #02:
+ // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
+ //
+ // substitute and simplify:
+ // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+ // grad[#02] = repeat(grad[#03], #02)
+ // grad[#02] = repeat(scale(grad[#05],#04), #02)
+ // grad[#02] = repeat(scale(grad[#07],#04), #02)
+ // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
+ // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
+ // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
+ // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
+ // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
+ // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
+ // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
+ // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+ // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
+ // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
+ // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
+ // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
+ // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
+ // a = b*c + d*e
+ // a = b*c*f/f + d*e*f/f
+ // a = (b*c*f + d*e*f)*(1/f)
+ // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
+ // a = (b + d*e/c)*c
+ // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
+ // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
+ // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
+ // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
+ // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
+ // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
+ // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
+ // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
+ // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+ // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+ }
+ // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+ // post-order:
+ // dx := x
+ // dx := scale(dx,-mean_xdz/mean_eps)
+ // dx := add(dx, dz)
+ // dx := scale(dx, rrms)
+ float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ ggml_vec_cpy_f32 (ne00, dx, x);
+ // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
+ ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
+ ggml_vec_acc_f32 (ne00, dx, dz);
+ ggml_vec_scale_f32(ne00, dx, rrms);
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rms_norm_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rms_norm_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_group_norm
+
+static void ggml_compute_forward_group_norm_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const float eps = 1e-6f; // TODO: make this a parameter
+
+ // TODO: optimize
+
+ int n_channels = src0->ne[2];
+ int n_groups = dst->op_params[0];
+ int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+ for (int i = ith; i < n_groups; i += nth) {
+ int start = i * n_channels_per_group;
+ int end = start + n_channels_per_group;
+ if (end > n_channels) {
+ end = n_channels;
+ }
+ int step = end - start;
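+        // mean and variance are computed over all ne00*ne01 positions of the `step` channels
+        // in this group, separately for each batch index i03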
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ ggml_float sum = 0.0;
+ for (int64_t i02 = start; i02 < end; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+ ggml_float sumr = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sumr += (ggml_float)x[i00];
+ }
+ sum += sumr;
+ }
+ }
+ const float mean = sum / (ne00 * ne01 * step);
+
+ ggml_float sum2 = 0.0;
+ for (int64_t i02 = start; i02 < end; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+ float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+ ggml_float sumr = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ float v = x[i00] - mean;
+ y[i00] = v;
+ sumr += (ggml_float)(v * v);
+ }
+ sum2 += sumr;
+ }
+ }
+ const float variance = sum2 / (ne00 * ne01 * step);
+ const float scale = 1.0f / sqrtf(variance + eps);
+
+ for (int64_t i02 = start; i02 < end; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+ ggml_vec_scale_f32(ne00, y, scale);
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_group_norm(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_group_norm_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_mul_mat
+
+static void ggml_compute_forward_mul_mat_one_chunk(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const int64_t num_rows_per_vec_dot,
+ const int64_t ir0_start,
+ const int64_t ir0_end,
+ const int64_t ir1_start,
+ const int64_t ir1_end) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const enum ggml_type type = src0->type;
+
+ const bool src1_cont = ggml_is_contiguous(src1);
+
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+
+ // broadcast factors
+ const int64_t r2 = ne12 / ne02;
+ const int64_t r3 = ne13 / ne03;
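+    // r2/r3 = how many src1 batches share each src0 batch (src0 is broadcast along dims 2 and 3)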
+
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
+
+ // threads with no work simply yield (not sure if it helps)
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+ return;
+ }
+
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+ assert(ne12 % ne02 == 0);
+ assert(ne13 % ne03 == 0);
+
+ // block-tiling attempt
+ const int64_t blck_0 = 16;
+ const int64_t blck_1 = 16;
+
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+ // attempt to reduce false-sharing (does not seem to make a difference)
+ // 16 * 2, accounting for mmla kernels
+ float tmp[32];
+
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+ const int64_t i13 = (ir1 / (ne12 * ne1));
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+ // broadcast src0 into src1
+ const int64_t i03 = i13 / r3;
+ const int64_t i02 = i12 / r2;
+
+ const int64_t i1 = i11;
+ const int64_t i2 = i12;
+ const int64_t i3 = i13;
+
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+ // the original src1 data pointer, so we should index using the indices directly
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
+ const char * src1_col = (const char*)wdata +
+ (src1_cont || src1->type != vec_dot_type
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+ //}
+
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+ }
+
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_mul_mat(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+ ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
+ ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
+ int64_t const vec_dot_num_rows = type_traits[type].nrows;
+ int64_t const matmul_num_cols = type_traits[type].ncols;
+ int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
+ ggml_gemv_t const gemv = type_traits[type].gemv;
+ ggml_gemm_t const gemm = type_traits[type].gemm;
+
+ GGML_ASSERT(ne0 == ne01);
+ GGML_ASSERT(ne1 == ne11);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+
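+    // fast paths: try iqk_mul_mat first (when enabled), then llamafile sgemm;
+    // if neither handles this case we fall through to the generic ggml path below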
+#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE
+ // broadcast factors
+ const int64_t r2 = ne12 / ne02;
+ const int64_t r3 = ne13 / ne03;
+#endif
+
+#if GGML_USE_IQK_MULMAT
+ if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+ int counter = 0;
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
+ if (counter++ % nth == ith) {
+ if (!iqk_mul_mat(ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ 0, 1)) goto IQK_MulMat_Not_Available1;
+ }
+ }
+ }
+ return;
+ }
+ if (dst->type == GGML_TYPE_F32) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!iqk_mul_mat(ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available1;
+ return;
+ }
+IQK_MulMat_Not_Available1:;
+#endif
+
+#if GGML_USE_LLAMAFILE
+
+ const bool src1_cont = ggml_is_contiguous(src1);
+
+ if (src1_cont) {
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+ nb01/ggml_type_size(src0->type),
+ (const char *)src1->data + i12*nb12 + i13*nb13,
+ nb11/ggml_type_size(src1->type),
+ (char *)dst->data + i12*nb2 + i13*nb3,
+ nb1/ggml_type_size(dst->type),
+ ith, nth,
+ src0->type,
+ src1->type,
+ dst->type))
+ goto UseGgmlGemm1;
+ return;
+ }
+UseGgmlGemm1:;
+#endif
+
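+    // convert src1 to vec_dot_type into wdata, splitting the rows across threads
+    // (using the 4-row interleaved conversion when from_float_to_mat/gemm are available)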
+ if (src1->type != vec_dot_type) {
+ char * wdata = params->wdata;
+
+ const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+ const size_t nbw2 = nbw1*ne11;
+ const size_t nbw3 = nbw2*ne12;
+
+ assert(params->wsize >= ne13*nbw3);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ int64_t i11_processed = 0;
+ if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+ from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+ 4, ne10, blck_size_interleave);
+ }
+ i11_processed = ne11 - ne11 % 4;
+ }
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+ ne10);
+ }
+ }
+ }
+ }
+
+ if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+ atomic_store(&params->shared->current_chunk, nth);
+ }
+
+ ggml_barrier(params->shared);
+
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+
+#if GGML_USE_IQK_MULMAT
+ if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) {
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!iqk_mul_mat(ne01, ne11, ne00,
+ src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
+ (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
+ ith, nth)) goto IQK_MulMat_Not_Available2;
+ return;
+ }
+IQK_MulMat_Not_Available2:;
+#endif
+
+#if GGML_USE_LLAMAFILE
+ if (src1->type != vec_dot_type) {
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+ for (int64_t i13 = 0; i13 < ne13; i13++)
+ for (int64_t i12 = 0; i12 < ne12; i12++)
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+ nb01/ggml_type_size(src0->type),
+ (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
+ row_size/ggml_type_size(vec_dot_type),
+ (char *)dst->data + i12*nb2 + i13*nb3,
+ nb1/ggml_type_size(dst->type),
+ ith, nth,
+ src0->type,
+ vec_dot_type,
+ dst->type))
+ goto UseGgmlGemm2;
+ return;
+ }
+UseGgmlGemm2:;
+#endif
+
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+ const int64_t nr0 = ne0;
+
+ // This is the size of the rest of the dimensions of the result
+ const int64_t nr1 = ne1 * ne2 * ne3;
+
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
+ // TODO: currently the mmla kernels support only even numbered rows/cols.
+ // this check can be removed once they are extended to support odd numbered rows/cols too
+ if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
+ num_rows_per_vec_dot = 1;
+ }
+
+ // Now select a reasonable chunk size.
+ int chunk_size = 16;
+
+ // We need to step up the size if it's small
+ if (nr0 == 1 || nr1 == 1) {
+ chunk_size = 64;
+ }
+
+ // distribute the work across the inner or outer loop based on which one is larger
+ // The number of chunks in the 0/1 dim.
+ // CEIL(nr0/chunk_size)
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
+
+    // If the chunking is poor for the number of threads on this setup, scrap the whole plan and re-chunk by thread.
+    // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
+    // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
+ // distribute the thread work across the inner or outer loop based on which one is larger
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+ }
+
+ // The number of elements in each chunk
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
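+    // fast path for interleaved types: give each thread a slice of src0 rows aligned to matmul_num_cols;
+    // the gemm kernel handles src1 rows in blocks of 4, gemv handles the remainder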
+ if ((ggml_n_dims(src0) == 2) && gemv) {
+ const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+ const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
+ int64_t src0_start = (ith * ne01) / nth;
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
+ src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
+ src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
+ if (src0_start >= src0_end) return;
+
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+ if (gemm && (ne11 > 3)) {
+ gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+ }
+ for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
+ gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
+ src0_end - src0_start);
+ }
+ return;
+ }
+
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk0 * nchunk1) {
+ const int64_t ith0 = current_chunk % nchunk0;
+ const int64_t ith1 = current_chunk / nchunk0;
+
+ const int64_t ir0_start = dr0 * ith0;
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+ const int64_t ir1_start = dr1 * ith1;
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+
+ if (nth >= nchunk0 * nchunk1) {
+ break;
+ }
+
+ current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
+ }
+}
+
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * ids = dst->src[2];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+
+ const bool src1_cont = ggml_is_contiguous(src1);
+
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+ ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
+ int64_t const matmul_num_cols = type_traits[type].ncols;
+ ggml_gemv_t const gemv = type_traits[type].gemv;
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // row groups
+ const int n_ids = ids->ne[0]; // n_expert_used
+ const int n_as = ne02; // n_expert
+
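+    // wdata layout: [src1 converted to vec_dot_type (only if needed)][matrix_row_counts: n_as int64][matrix_rows mapping]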
+ char * wdata_src1_end = (src1->type == vec_dot_type) ?
+ (char *) params->wdata :
+ (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+
+ struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+ };
+
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+
+ if (src1->type != vec_dot_type) {
+ char * wdata = params->wdata;
+
+ const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+ const size_t nbw2 = nbw1*ne11;
+ const size_t nbw3 = nbw2*ne12;
+
+ assert(params->wsize >= ne13*nbw3);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+ from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+ ne10);
+ }
+ }
+ }
+ }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+ if (ith == 0) {
+ // initialize matrix_row_counts
+ memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+ // group rows by src0 matrix
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+ for (int id = 0; id < n_ids; ++id) {
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ assert(i02 >= 0 && i02 < n_as);
+
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+ matrix_row_counts[i02] += 1;
+ }
+ }
+ }
+
+ ggml_barrier(params->shared);
+
+ // compute each matrix multiplication in sequence
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+ const int64_t cne1 = matrix_row_counts[cur_a];
+
+ if (cne1 == 0) {
+ continue;
+ }
+
+ const char * src0_cur = (const char *) src0->data + cur_a*nb02;
+
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+ const int64_t nr0 = ne01; // src0 rows
+ const int64_t nr1 = cne1; // src1 rows
+ //
+#if GGML_USE_IQK_MULMAT
+ if (ne13 == 1 && dst->type == GGML_TYPE_F32) {
+ if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11,
+ src0->type, (const char *)src0_cur, nb01/ggml_type_size(src0->type),
+ vec_dot_type, (const char *)wdata, row_size/ggml_type_size(vec_dot_type),
+ (float *)dst->data, nb1, nb2,
+ matrix_rows + cur_a*ne12, ith, nth)) goto IQK_MulMat_Not_Available;
+ continue;
+ }
+IQK_MulMat_Not_Available:;
+#endif
+
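+        // fast path for interleaved types: each thread processes an aligned slice of src0 rows with gemv,
+        // one call per src1 row selected for this expert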
+ if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
+ int64_t src0_cur_start = (ith * ne01) / nth;
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+ src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
+ src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
+ if (src0_cur_start >= src0_cur_end) return;
+
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+ const int id = row_mapping.i1; // selected expert index
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1
+
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row
+
+ const char * src1_col = (const char *) wdata +
+ (src1_cont || src1->type != vec_dot_type
+ ? (i11 + i12 * ne11) * row_size
+ : (i11 * nb11 + i12 * nb12));
+
+ gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+ (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
+ }
+ continue;
+ }
+
+ // distribute the thread work across the inner or outer loop based on which one is larger
+
+ const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+ const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+ const int64_t ith0 = ith % nth0;
+ const int64_t ith1 = ith / nth0;
+
+ const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+ const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+ const int64_t ir010 = dr0*ith0;
+ const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+ const int64_t ir110 = dr1*ith1;
+ const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+ // threads with no work simply yield (not sure if it helps)
+ //if (ir010 >= ir011 || ir110 >= ir111) {
+ // sched_yield();
+ // continue;
+ //}
+
+ // block-tiling attempt
+ const int64_t blck_0 = 16;
+ const int64_t blck_1 = 16;
+
+ // attempt to reduce false-sharing (does not seem to make a difference)
+ float tmp[16];
+
+ for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+ for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+ const int64_t _i12 = ir1; // logical row index for this expert
+
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+ const int id = row_mapping.i1; // selected expert index
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1
+
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row
+
+                    // when src1 is not a contiguous memory block, we have to calculate the offset using the strides.
+                    // if it is contiguous, then we have either copied the data to params->wdata and made it contiguous,
+                    // or we are using the original src1 data pointer, so we can index with the indices directly.
+                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
+ const char * src1_col = (const char *) wdata +
+ (src1_cont || src1->type != vec_dot_type
+ ? (i11 + i12*ne11)*row_size
+ : (i11*nb11 + i12*nb12));
+
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+ //}
+
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+ }
+
+ memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+ }
+ }
+ }
+ }
+
+#undef MMID_MATRIX_ROW
+}
+
+// ggml_compute_forward_out_prod
+
+static void ggml_compute_forward_out_prod_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+ GGML_ASSERT(ne03 == ne13);
+
+ // we don't support permuted src0 or src1
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ // GGML_ASSERT(nb0 <= nb1);
+ // GGML_ASSERT(nb1 <= nb2);
+ // GGML_ASSERT(nb2 <= nb3);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+
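+    // a single thread zero-initializes dst, then all threads synchronize before accumulating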
+ if (ith == 0) {
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+ }
+ ggml_barrier(params->shared);
+
+ // dst[:,:,:,:] = 0
+ // for i2,i3:
+ // for i1:
+ // for i01:
+ // for i0:
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+ // parallelize by last three dimensions
+
+ // total rows in dst
+ const int64_t nr = ne1*ne2*ne3;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ // block-tiling attempt
+ const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+ const int64_t blck_1 = 16;
+
+ for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+ const int64_t bir1 = MIN(bir + blck_1, ir1);
+ for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+ const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+ for (int64_t ir = bir; ir < bir1; ++ir) {
+ // dst indices
+ const int64_t i3 = ir/(ne2*ne1);
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ //const int64_t i10 = i1;
+ const int64_t i12 = i2;
+ const int64_t i13 = i3;
+
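+            // accumulate src1[i1,i01] * (src0 row i01) into the dst row;
+            // the unrolled variant processes GGML_VEC_MAD_UNROLL source rows per call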
+#if GGML_VEC_MAD_UNROLL > 2
+ const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+ for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+ }
+ for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#else
+ for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ ggml_vec_mad_f32(ne0, d, s0, *s1);
+ }
+#endif
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const enum ggml_type type = src0->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
+ // we don't support permuted src0 dim0
+ GGML_ASSERT(nb00 == ggml_type_size(type));
+
+ // dst dim0 cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ // GGML_ASSERT(nb0 <= nb1);
+ // GGML_ASSERT(nb1 <= nb2);
+ // GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+ GGML_ASSERT(ne2 == ne02);
+ GGML_ASSERT(ne3 == ne03);
+
+ // nb01 >= nb00 - src0 is not transposed
+ // compute by src0 rows
+
+ if (ith == 0) {
+ ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+ }
+ ggml_barrier(params->shared);
+
+ // parallelize by last three dimensions
+
+ // total rows in dst
+ const int64_t nr = ne1*ne2*ne3;
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ // dst[:,:,:,:] = 0
+ // for i2,i3:
+ // for i1:
+ // for i01:
+ // for i0:
+ // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
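+    // per-thread scratch row used to dequantize one src0 row at a time (padded by a cache line to avoid false sharing)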
+ float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ // dst indices
+ const int64_t i3 = ir/(ne2*ne1);
+ const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+ const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ //const int64_t i10 = i1;
+ const int64_t i12 = i2;
+ const int64_t i13 = i3;
+
+ for (int64_t i01 = 0; i01 < ne01; ++i01) {
+ const int64_t i11 = i01;
+
+ float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
+ float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+
+ dequantize_row_q(s0, wdata, ne0);
+ ggml_vec_mad_f32(ne0, d, wdata, *s1);
+ }
+ }
+}
+
+static void ggml_compute_forward_out_prod(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ {
+ ggml_compute_forward_out_prod_q_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(false); // todo
+ // ggml_compute_forward_out_prod_f16_f32(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_out_prod_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ // scale factor
+ float v;
+ memcpy(&v, dst->op_params, sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb1 = dst->nb[1];
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ if (dst->data != src0->data) {
+ // src0 is same shape as dst => same indices
+ memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+ }
+ ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+ }
+}
+
+static void ggml_compute_forward_scale(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_scale_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_set
+
+static void ggml_compute_forward_set_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during set
+ // nb0 is implicitly element_size because src0 and dst are contiguous
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
+ size_t offset = ((int32_t *) dst->op_params)[3];
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ if (params->ith == 0) {
+            // the memcpy needs to be synchronized across threads to avoid race conditions:
+            // only thread 0 copies, then everyone waits at the barrier below
+ memcpy(
+ ((char *) dst->data),
+ ((char *) src0->data),
+ ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src1);
+ const int nc = src1->ne[0];
+
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
+ // src0 and dst as viewed during set
+ const size_t nb0 = ggml_element_size(src0);
+
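+    // indices of the last element of the src1-shaped view, used to verify the view fits inside dst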
+ const int im0 = (ne10 == 0 ? 0 : ne10-1);
+ const int im1 = (ne11 == 0 ? 0 : ne11-1);
+ const int im2 = (ne12 == 0 ? 0 : ne12-1);
+ const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+ GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
+
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 and dst are viewed with shape of src1 and offset
+ // => same indices
+ const int i3 = ir/(ne12*ne11);
+ const int i2 = (ir - i3*ne12*ne11)/ne11;
+ const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
+ (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ }
+}
+
+static void ggml_compute_forward_set(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_set_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_cpy
+
+static void ggml_compute_forward_cpy(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_reshape
+
+static void ggml_compute_forward_reshape(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+static void ggml_compute_forward_view(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_permute
+
+static void ggml_compute_forward_permute(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_transpose
+
+static void ggml_compute_forward_transpose(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * dst) {
+ // NOP
+ UNUSED(params);
+ UNUSED(dst);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_q(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1);
+
+ const enum ggml_type type = src0->type;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == ggml_type_size(type));
+ assert(ggml_nrows(dst) == nr);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
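+    // decompose the flat src1 index into (i10, i11, i12), read the row id i01 and dequantize that src0 row into dst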
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ assert(i01 >= 0 && i01 < ne01);
+
+ dequantize_row_q(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+}
+
+static void ggml_compute_forward_get_rows_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1);
+
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(ggml_fp16_t));
+ assert(ggml_nrows(dst) == nr);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ assert(i01 >= 0 && i01 < ne01);
+
+ ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+}
+
+static void ggml_compute_forward_get_rows_bf16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1);
+
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(ggml_bf16_t));
+ assert(ggml_nrows(dst) == nr);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ assert(i01 >= 0 && i01 < ne01);
+
+ ggml_bf16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1);
+
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(float));
+ assert(ggml_nrows(dst) == nr);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int64_t i = ir0; i < ir1; ++i) {
+ const int64_t i12 = i/(ne11*ne10);
+ const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+ const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ assert(i01 >= 0 && i01 < ne01);
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ }
+}
+
+static void ggml_compute_forward_get_rows(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ {
+ ggml_compute_forward_get_rows_q(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_get_rows_f16(params, dst);
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ ggml_compute_forward_get_rows_bf16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_get_rows_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ //static bool first = true;
+ //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+ //if (first) {
+ // first = false;
+ //} else {
+ // for (int k = 0; k < dst->ne[1]; ++k) {
+ // for (int j = 0; j < dst->ne[0]/16; ++j) {
+ // for (int i = 0; i < 16; ++i) {
+ // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+ // }
+ // printf("\n");
+ // }
+ // printf("\n");
+ // }
+ // printf("\n");
+ // exit(0);
+ //}
+}
+
+// ggml_compute_forward_get_rows_back
+
+static void ggml_compute_forward_get_rows_back_f32_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+ memset(dst->data, 0, ggml_nbytes(dst));
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ GGML_ASSERT( dst->ne[0] == nc);
+ GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ for (int j = 0; j < nc; ++j) {
+ ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
+ ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+ }
+ }
+}
+
+static void ggml_compute_forward_get_rows_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+ memset(dst->data, 0, ggml_nbytes(dst));
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nelements(src1);
+
+ GGML_ASSERT( dst->ne[0] == nc);
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < nr; ++i) {
+ const int r = ((int32_t *) src1->data)[i];
+
+ ggml_vec_add_f32(nc,
+ (float *) ((char *) dst->data + r*dst->nb[1]),
+ (float *) ((char *) dst->data + r*dst->nb[1]),
+ (float *) ((char *) src0->data + i*src0->nb[1]));
+ }
+}
+
+static void ggml_compute_forward_get_rows_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_get_rows_back_f32_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_get_rows_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ //static bool first = true;
+ //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+ //if (first) {
+ // first = false;
+ //} else {
+ // for (int k = 0; k < dst->ne[1]; ++k) {
+ // for (int j = 0; j < dst->ne[0]/16; ++j) {
+ // for (int i = 0; i < 16; ++i) {
+ // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+ // }
+ // printf("\n");
+ // }
+ // printf("\n");
+ // }
+ // printf("\n");
+ // exit(0);
+ //}
+}
+
+// ggml_compute_forward_diag
+
+static void ggml_compute_forward_diag_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ // TODO: handle transposed/permuted matrices
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(ne00 == ne0);
+ GGML_ASSERT(ne00 == ne1);
+ GGML_ASSERT(ne01 == 1);
+ GGML_ASSERT(ne02 == ne2);
+ GGML_ASSERT(ne03 == ne3);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ for (int i3 = 0; i3 < ne3; i3++) {
+ for (int i2 = 0; i2 < ne2; i2++) {
+ for (int i1 = 0; i1 < ne1; i1++) {
+ float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+ float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
+ for (int i0 = 0; i0 < i1; i0++) {
+ d[i0] = 0;
+ }
+ d[i1] = s[i1];
+ for (int i0 = i1+1; i0 < ne0; i0++) {
+ d[i0] = 0;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_diag(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_diag_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const float value) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const bool inplace = src0->data == dst->data;
+
+ GGML_ASSERT(n_past >= 0);
+
+ if (!inplace) {
+ if (ith == 0) {
+            // the memcpy needs to be synchronized across threads to avoid race conditions:
+            // only thread 0 copies, then everyone waits at the barrier below
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+ memcpy(
+ ((char *) dst->data),
+ ((char *) src0->data),
+ ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+ }
+
+ // TODO: handle transposed/permuted matrices
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+ const int nr = src0->ne[1];
+ const int nz = n/nr;
+
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
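+    // for each row j, set every element past the (shifted) diagonal, i.e. columns i > n_past + j, to `value`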
+ for (int k = 0; k < nz; k++) {
+ for (int j = ith; j < nr; j += nth) {
+ for (int i = n_past; i < nc; i++) {
+ if (i > n_past + j) {
+ *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_diag_mask_inf(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+static void ggml_compute_forward_diag_mask_zero(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_diag_mask_f32(params, dst, 0);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ assert(ggml_is_contiguous(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+ // TODO: handle transposed/permuted matrices
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+ // TODO: is this supposed to be ceil instead of floor?
+ // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+ const uint32_t n_head = ne02;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
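+    // ALiBi slope bases: heads below n_head_log2 use powers of m0, the remaining heads use odd powers of m1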
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ // ALiBi
+ const uint32_t h = (i1/ne01)%ne02; // head
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+ // broadcast the mask across rows
+ ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+ float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+ ggml_vec_cpy_f32 (nc, wp, sp);
+ ggml_vec_scale_f32(nc, wp, scale);
+ if (mp_f32) {
+ if (use_f16) {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+ }
+ } else {
+ for (int i = 0; i < nc; ++i) {
+ wp[i] += slope*mp_f32[i];
+ }
+ }
+ }
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(wp[i]));
+ }
+#endif
+
+ float max = -INFINITY;
+ ggml_vec_max_f32(nc, &max, wp);
+
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+ assert(sum > 0.0);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(dp[i]));
+ assert(!isinf(dp[i]));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_soft_max(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_soft_max_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+ // TODO: handle transposed/permuted matrices
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
+ float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(dy[i]));
+ assert(!isnan(y[i]));
+ }
+#endif
+ // Jii = yi - yi*yi
+ // Jij = -yi*yj
+ // J = diag(y)-y.T*y
+ // dx = J * dy
+ // dxk = sum_i(Jki * dyi)
+ // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+ // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
+ // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+ // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+ // dxk = -yk * dot(y, dy) + yk*dyk
+ // dxk = yk * (- dot(y, dy) + dyk)
+ // dxk = yk * (dyk - dot(y, dy))
+ //
+ // post-order:
+ // dot_y_dy := dot(y, dy)
+ // dx := dy
+ // dx := dx - dot_y_dy
+ // dx := dx * y
+
+ // linear runtime, no additional memory
+ float dot_y_dy = 0;
+ ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+ ggml_vec_cpy_f32 (nc, dx, dy);
+ ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+ ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(dx[i]));
+ assert(!isinf(dx[i]));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_soft_max_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_soft_max_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_clamp
+
+static void ggml_compute_forward_clamp_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ float min;
+ float max;
+ memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ const size_t nb00 = src0->nb[0];
+ const size_t nb01 = src0->nb[1];
+
+ const size_t nb0 = dst->nb[0];
+ const size_t nb1 = dst->nb[1];
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ for (int j = ith; j < n; j += nth) {
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+ for (int i = 0; i < nc; i++) {
+ dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+ }
+ }
+}
+
+static void ggml_compute_forward_clamp(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_clamp_f32(params, dst);
+ } break;
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_BN:
+ case GGML_TYPE_IQ2_BN:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_Q8_K:
+ case GGML_TYPE_Q8_K64:
+ case GGML_TYPE_Q4_0_4_4:
+ case GGML_TYPE_Q4_0_4_8:
+ case GGML_TYPE_Q4_0_8_8:
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_I64:
+ case GGML_TYPE_F64:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_rope
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+ return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ float * cos_theta, float * sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ *cos_theta = cosf(theta) * mscale;
+ *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+static void ggml_rope_cache_init(
+ float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale) {
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
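+    // the cache stores interleaved (cos, sin) values (scaled by mscale, with sin multiplied by sin_sign) for each dimension pair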
+ float theta = theta_base;
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+ rope_yarn(
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;
+
+ theta *= theta_scale;
+ }
+}
+
+GGML_CALL void ggml_rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+ // start and end correction dims
+ float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+ float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
+ dims[0] = MAX(0, start);
+ dims[1] = MIN(n_dims - 1, end);
+}
+
+static void ggml_compute_forward_rope_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const bool forward) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+ //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(dst);
+
+ GGML_ASSERT(n_dims <= ne0);
+ GGML_ASSERT(n_dims % 2 == 0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ // row index used to determine which thread to use
+ int ir = 0;
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
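+    // mode bit 1 selects the NeoX-style rotation (pairing i0 with i0 + n_dims/2) instead of rotating adjacent pairs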
+ const bool is_neox = mode & 2;
+
+ const float * freq_factors = NULL;
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+
+ // backward process uses inverse rotation by cos and sin.
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
+ // this essentially just switches the sign of sin.
+ const float sin_sign = forward ? 1.0f : -1.0f;
+
+ const int32_t * pos = (const int32_t *) src1->data;
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
+ const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ if (ir++ < ir0) continue;
+ if (ir > ir1) break;
+
+ if (!is_neox) {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[1];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
+ }
+ } else {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = src[0];
+ const float x1 = src[n_dims/2];
+
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+ }
+
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+ }
+ }
+}
+
+// TODO: deduplicate f16/f32 code
+static void ggml_compute_forward_rope_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const bool forward) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+ //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(dst);
+
+ GGML_ASSERT(n_dims <= ne0);
+ GGML_ASSERT(n_dims % 2 == 0);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+    // running row counter, used to skip rows that belong to other threads
+ int ir = 0;
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
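+    // theta_scale is compounded once per rotated pair, so pair i0/2 uses theta = p * freq_base^(-i0/n_dims)
+    // (before the optional freq_factors division and the YaRN correction applied in ggml_rope_cache_init)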
+
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+ const bool is_neox = mode & 2;
+
+ const float * freq_factors = NULL;
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+
+ // backward process uses inverse rotation by cos and sin.
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
+ // this essentially just switches the sign of sin.
+ const float sin_sign = forward ? 1.0f : -1.0f;
+
+ const int32_t * pos = (const int32_t *) src1->data;
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
+ const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ if (ir++ < ir0) continue;
+ if (ir > ir1) break;
+
+ if (!is_neox) {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
+
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ }
+ } else {
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+ const int64_t ic = i0/2;
+
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
+
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+ }
+ }
+
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_rope(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_rope_f16(params, dst, true);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rope_f32(params, dst, true);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_rope_back
+
+static void ggml_compute_forward_rope_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_rope_f16(params, dst, false);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_rope_f32(params, dst, false);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_conv_transpose_1d
+
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00*ne01*ne02;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (ith == 0) {
+ memset(params->wdata, 0, params->wsize);
+
+ // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+ ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ne02 + i02] = src[i00];
+ }
+ }
+ }
+ }
+
+ // permute source data (src1) from (L x Cin) to (Cin x L)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+ ggml_fp16_t * dst_data = wdata;
+
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
+ dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+ }
+ }
+ }
+
+ // need to zero dst since we are accumulating into it
+ memset(dst->data, 0, ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+ // total rows in dst
+ const int nr = ne1;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+ ggml_fp16_t * const wdata_src = wdata + nk;
+
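+    // each input column i10 spreads its contribution over the output columns [i10*s0, i10*s0 + ne00):
+    // for every kernel tap i00 a dot product over the input channels (ne02) is accumulated,
+    // so overlapping taps sum up - this is why dst was zeroed above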
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ const int i1n = i10*ne11;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ float v = 0;
+ ggml_vec_dot_f16(ne02, &v, 0,
+ (ggml_fp16_t *) wdata_src + i1n, 0,
+ (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
+ dst_data[i10*s0 + i00] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_transpose_1d_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00*ne01*ne02;
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (ith == 0) {
+ memset(params->wdata, 0, params->wsize);
+
+ // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+ {
+ float * const wdata = (float *) params->wdata + 0;
+
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+ float * dst_data = wdata + i01*ne00*ne02;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ dst_data[i00*ne02 + i02] = src[i00];
+ }
+ }
+ }
+ }
+
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+ {
+ float * const wdata = (float *) params->wdata + nk;
+ float * dst_data = wdata;
+
+ for (int64_t i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i11*nb11);
+ for (int64_t i10 = 0; i10 < ne10; i10++) {
+ dst_data[i10*ne11 + i11] = src[i10];
+ }
+ }
+ }
+
+ // need to zero dst since we are accumulating into it
+ memset(dst->data, 0, ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+ // total rows in dst
+ const int nr = ne1;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ float * const wdata = (float *) params->wdata + 0;
+ float * const wdata_src = wdata + nk;
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * dst_data = (float *)((char *) dst->data + i1*nb1);
+ float * wdata_kernel = wdata + i1*ne02*ne00;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ const int i1n = i10*ne11;
+ for (int i00 = 0; i00 < ne00; i00++) {
+ float v = 0;
+ ggml_vec_dot_f32(ne02, &v, 0,
+ wdata_src + i1n, 0,
+ wdata_kernel + i00*ne02, 0, 1);
+ dst_data[i10*s0 + i00] += v;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_conv_transpose_1d(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_conv_transpose_1d_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst: result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne13 : ne12;
+ const int64_t IC = is_2D ? ne12 : ne11;
+ const int64_t IH = is_2D ? ne11 : 1;
+ const int64_t IW = ne10;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne2 : 1;
+ const int64_t OW = ne1;
+
+    // byte offsets into src1 (64-bit to avoid truncation of large strides)
+    const int64_t ofs0 = is_2D ? nb13 : nb12;
+    const int64_t ofs1 = is_2D ? nb12 : nb11;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
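+    // each output element is a gather from the padded input:
+    //   dst[in, ioh, iow, iic*KH*KW + ikh*KW + ikw] = src1[in, iic, ioh*s1 + ikh*d1 - p1, iow*s0 + ikw*d0 - p0]
+    // positions that fall outside the image are written as 0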
+ {
+ float * const wdata = (float *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 if !is_2D
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) { // KH == 1 if !is_2D
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+ } else {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst: result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t N = is_2D ? ne13 : ne12;
+ const int64_t IC = is_2D ? ne12 : ne11;
+ const int64_t IH = is_2D ? ne11 : 1;
+ const int64_t IW = ne10;
+
+ const int64_t KH = is_2D ? ne01 : 1;
+ const int64_t KW = ne00;
+
+ const int64_t OH = is_2D ? ne2 : 1;
+ const int64_t OW = ne1;
+
+    // byte offsets into src1 (64-bit to avoid truncation of large strides)
+    const int64_t ofs0 = is_2D ? nb13 : nb12;
+    const int64_t ofs1 = is_2D ? nb12 : nb11;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 if !is_2D
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic += nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) { // KH == 1 if !is_2D
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+ } else {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_im2col(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_im2col_f16(params, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_im2col_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
+// ggml_compute_forward_conv_transpose_2d
+
+static void ggml_compute_forward_conv_transpose_2d(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nk = ne00*ne01*ne02*ne03;
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (ith == 0) {
+ memset(params->wdata, 0, params->wsize);
+
+ // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+ ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
+ }
+ }
+ }
+ }
+ }
+
+ // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+ for (int i12 = 0; i12 < ne12; i12++) {
+ for (int i11 = 0; i11 < ne11; i11++) {
+ const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
+ ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+ for (int i10 = 0; i10 < ne10; i10++) {
+ dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
+ }
+ }
+ }
+ }
+
+ memset(dst->data, 0, ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+
+ const int32_t stride = ggml_get_op_params_i32(dst, 0);
+
+ // total patches in dst
+ const int np = ne2;
+
+ // patches per thread
+ const int dp = (np + nth - 1)/nth;
+
+ // patch range for this thread
+ const int ip0 = dp*ith;
+ const int ip1 = MIN(ip0 + dp, np);
+
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+ ggml_fp16_t * const wdata_src = wdata + nk;
+
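+    // for each output channel i2, every input pixel (i11, i10) and kernel tap (i01, i00) contributes
+    // a dot product over the input channels (ne03) to the output pixel (i11*stride + i01, i10*stride + i00)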
+ for (int i2 = ip0; i2 < ip1; i2++) { // Cout
+ float * dst_data = (float *)((char *) dst->data + i2*nb2);
+ ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+ for (int i11 = 0; i11 < ne11; i11++) {
+ for (int i10 = 0; i10 < ne10; i10++) {
+ const int i1n = i11*ne10*ne12 + i10*ne12;
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ float v = 0;
+ ggml_vec_dot_f16(ne03, &v, 0,
+ wdata_src + i1n, 0,
+ wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+ dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
+ }
+ }
+ }
+ }
+ }
+}
+
+// ggml_compute_forward_pool_1d_sk_p0
+
+static void ggml_compute_forward_pool_1d_sk_p0(
+ const struct ggml_compute_params * params,
+ const enum ggml_op_pool op,
+ const int k,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src = dst->src[0];
+
+ assert(src->type == GGML_TYPE_F32);
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ const char * cdata = (const char *)src->data;
+ const char * const data_end = cdata + ggml_nbytes(src);
+ float * drow = (float *)dst->data;
+
+ const int64_t rs = dst->ne[0];
+
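+    // stride == kernel size and no padding (enforced by the caller),
+    // so output element i reduces the window srow[i*k .. i*k + k - 1]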
+ while (cdata < data_end) {
+ const float * const srow = (const float *)cdata;
+
+ int j = 0;
+
+ for (int64_t i = 0; i < rs; ++i) {
+ switch (op) {
+ case GGML_OP_POOL_AVG: drow[i] = 0; break;
+ case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+ for (int ki = 0; ki < k; ++ki) {
+ switch (op) {
+ case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
+ case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+ ++j;
+ }
+ switch (op) {
+ case GGML_OP_POOL_AVG: drow[i] /= k; break;
+ case GGML_OP_POOL_MAX: break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+ }
+
+ cdata += src->nb[1];
+ drow += rs;
+ }
+}
+
+// ggml_compute_forward_pool_1d
+
+static void ggml_compute_forward_pool_1d(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = opts[0];
+ const int k0 = opts[1];
+ const int s0 = opts[2];
+ const int p0 = opts[3];
+ GGML_ASSERT(p0 == 0); // padding not supported
+ GGML_ASSERT(k0 == s0); // only s = k supported
+
+ ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+}
+
+// ggml_compute_forward_pool_2d
+
+static void ggml_compute_forward_pool_2d(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src = dst->src[0];
+
+ GGML_ASSERT(src->type == GGML_TYPE_F32);
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = opts[0];
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+ const char * cdata = (const char*)src->data;
+ const char * const data_end = cdata + ggml_nbytes(src);
+
+ const int64_t px = dst->ne[0];
+ const int64_t py = dst->ne[1];
+ const int64_t pa = px * py;
+
+ float * dplane = (float *)dst->data;
+
+ const int ka = k0 * k1;
+ const int offset0 = -p0;
+ const int offset1 = -p1;
+
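+    // the window for output (ox, oy) starts at input (ox*s0 - p0, oy*s1 - p1);
+    // out-of-range taps are skipped, but AVG still divides by the full window size ka,
+    // i.e. padded positions count as zeros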
+ while (cdata < data_end) {
+ for (int oy = 0; oy < py; ++oy) {
+ float * const drow = dplane + oy * px;
+ for (int ox = 0; ox < px; ++ox) {
+ float * const out = drow + ox;
+ switch (op) {
+ case GGML_OP_POOL_AVG: *out = 0; break;
+ case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+
+ const int ix = offset0 + ox * s0;
+ const int iy = offset1 + oy * s1;
+
+ for (int ky = 0; ky < k1; ++ky) {
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+ const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
+ for (int kx = 0; kx < k0; ++kx) {
+ int j = ix + kx;
+ if (j < 0 || j >= src->ne[0]) continue;
+ switch (op) {
+ case GGML_OP_POOL_AVG: *out += srow[j]; break;
+ case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+ }
+ }
+ switch (op) {
+ case GGML_OP_POOL_AVG: *out /= ka; break;
+ case GGML_OP_POOL_MAX: break;
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
+ }
+ }
+ }
+
+ cdata += src->nb[2];
+ dplane += pa;
+ }
+}
+
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const float sf0 = (float)ne0/src0->ne[0];
+ const float sf1 = (float)ne1/src0->ne[1];
+ const float sf2 = (float)ne2/src0->ne[2];
+ const float sf3 = (float)ne3/src0->ne[3];
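+    // nearest-neighbor upscale: each dst coordinate i maps back to the src coordinate floor(i/sf)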
+
+ // TODO: optimize
+
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3 / sf3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2 / sf2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const int64_t i01 = i1 / sf1;
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const int64_t i00 = i0 / sf0;
+
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+ *y = *x;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_upscale(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_upscale_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float * dst_ptr = (float *) dst->data;
+
+ // TODO: optimize
+
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ dst_ptr[dst_idx] = *src_ptr;
+ } else {
+ dst_ptr[dst_idx] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_pad(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_pad_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+
+// ggml_compute_forward_arange
+
+static void ggml_compute_forward_arange_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const float start = ggml_get_op_params_f32(dst, 0);
+ const float stop = ggml_get_op_params_f32(dst, 1);
+ const float step = ggml_get_op_params_f32(dst, 2);
+
+ const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+ GGML_ASSERT(ggml_nelements(dst) == steps);
+
+ for (int64_t i = ith; i < steps; i+= nth) {
+ float value = start + step * i;
+ ((float *)dst->data)[i] = value;
+ }
+}
+
+static void ggml_compute_forward_arange(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_arange_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
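+// ggml_compute_forward_timestep_embedding
+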
+static void ggml_compute_forward_timestep_embedding_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int dim = ggml_get_op_params_i32(dst, 0);
+ const int max_period = ggml_get_op_params_i32(dst, 1);
+
+ int half = dim / 2;
+
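+    // sinusoidal embedding: for j in [0, half), freq_j = max_period^(-j/half) and
+    //   embed[j]        = cos(t * freq_j)
+    //   embed[j + half] = sin(t * freq_j)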
+ for (int64_t i = 0; i < ne00; i++) {
+ float * embed_data = (float *)((char *) dst->data + i*nb1);
+ for (int64_t j = ith; j < half; j += nth) {
+ float timestep = ((float *)src0->data)[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+ }
+ if (dim % 2 != 0 && ith == 0) {
+ embed_data[dim] = 0.f;
+ }
+ }
+}
+
+static void ggml_compute_forward_timestep_embedding(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_timestep_embedding_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+        // standard C qsort cannot portably carry the row data as comparator context, so use a simple bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_argsort(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_argsort_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const struct ggml_tensor * mask,
+ struct ggml_tensor * dst) {
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t D = neq0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == D);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev0 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nev0 == D);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // broadcast factors
+ const int64_t rk2 = neq2/nek2;
+ const int64_t rk3 = neq3/nek3;
+
+ const int64_t rv2 = neq2/nev2;
+ const int64_t rv3 = neq3/nev3;
+
+    // parallelize by q rows; the dot product kernel (kq_vec_dot) is selected below according to the type of k
+
+ // total rows in q
+ const int nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
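+    // ALiBi slopes: head h < n_head_log2 uses slope m0^(h + 1), the remaining heads use m1^(2*(h - n_head_log2) + 1)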
+
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
+
+ // loop over n_batch and n_head
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // q indices
+ const int iq3 = ir/(neq2*neq1);
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+ const uint32_t h = iq2; // head index
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
+
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+
+ if (v->type == GGML_TYPE_F16) {
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+ } else {
+ memset(VKQ32, 0, D*sizeof(float));
+ }
+
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+ // k indices
+ const int ik3 = iq3 / rk3;
+ const int ik2 = iq2 / rk2;
+
+ // v indices
+ const int iv3 = iq3 / rv3;
+ const int iv2 = iq2 / rv2;
+
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+ q_to_vec_dot(pq, Q_q, D);
+
+ // online softmax / attention
+ // loop over n_kv and n_head_kv
+ // ref: https://arxiv.org/pdf/2112.05682.pdf
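+        // running state per q row: M is the max KQ value seen so far, S the softmax denominator, VKQ the weighted sum of V:
+        //   M_new = max(M, s);  ms = expf(M - M_new);  vs = expf(s - M_new)
+        //   VKQ  <- VKQ*ms + v*vs
+        //   S    <- S*ms + vs
+        // the final output is VKQ/S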
+ for (int64_t ic = 0; ic < nek1; ++ic) {
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+ if (mv == -INFINITY) {
+ continue;
+ }
+
+ float s; // KQ value
+
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+
+ s = s*scale + mv; // scale KQ value and apply mask
+
+ const float Mold = M;
+
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            if (v->type == GGML_TYPE_F16) {
+ if (s > M) {
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+ M = s;
+ ms = expf(Mold - M);
+
+ // V = V*expf(Mold - M)
+ ggml_vec_scale_f16(D, VKQ16, ms);
+ } else {
+ // no new maximum, ms == 1.0f, vs != 1.0f
+ vs = expf(s - M);
+ }
+
+ // V += v*expf(s - M)
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+ } else {
+ if (s > M) {
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+ M = s;
+ ms = expf(Mold - M);
+
+ // V = V*expf(Mold - M)
+ ggml_vec_scale_f32(D, VKQ32, ms);
+ } else {
+ // no new maximum, ms == 1.0f, vs != 1.0f
+ vs = expf(s - M);
+ }
+
+ v_to_float(v_data, V32, D);
+
+ // V += v*expf(s - M)
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ }
+
+ S = S*ms + vs; // scale and increment sum with partial sum
+ }
+
+ if (v->type == GGML_TYPE_F16) {
+ for (int64_t d = 0; d < D; ++d) {
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+ }
+ }
+
+ // V /= S
+ const float S_inv = 1.0f/S;
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
+
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // original
+ //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+ // permute(0, 2, 1, 3)
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+ }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const struct ggml_tensor * mask,
+ struct ggml_tensor * dst) {
+ switch (dst->op_params[2]) {
+ case GGML_PREC_DEFAULT:
+ case GGML_PREC_F32:
+ {
+ // uses F32 accumulators
+ ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+ const struct ggml_compute_params * params,
+ const bool masked,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * q = dst->src[0];
+ const struct ggml_tensor * k = dst->src[1];
+ const struct ggml_tensor * v = dst->src[2];
+ const struct ggml_tensor * d = dst->src[3];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
+ GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t D = neq0;
+ const int64_t N = neq1;
+ const int64_t P = nek1 - N;
+ const int64_t M = P + N;
+
+ const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+ const int mxDM = MAX(D, Mup);
+
+ // GGML_ASSERT(ne0 == D);
+ // GGML_ASSERT(ne1 == N);
+ GGML_ASSERT(P >= 0);
+
+ GGML_ASSERT(nbq0 == sizeof(float));
+ GGML_ASSERT(nbk0 == sizeof(float));
+ GGML_ASSERT(nbv0 == sizeof(float));
+
+ GGML_ASSERT(neq0 == D);
+ GGML_ASSERT(nek0 == D);
+ GGML_ASSERT(nev1 == D);
+ GGML_ASSERT(ned0 == D);
+
+ GGML_ASSERT(neq1 == N);
+ GGML_ASSERT(nek1 == N + P);
+ GGML_ASSERT(nev1 == D);
+ GGML_ASSERT(ned1 == N);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ if (ith == 0) {
+ memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+ }
+ ggml_barrier(params->shared);
+
+ const int64_t elem_q = ggml_nelements(q);
+ const int64_t elem_k = ggml_nelements(k);
+
+ enum ggml_type result_type = dst->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+ void * grad_q = (char *) dst->data;
+ void * grad_k = (char *) dst->data + offs_k;
+ void * grad_v = (char *) dst->data + offs_v;
+
+ const size_t nbgq1 = nb0*neq0;
+ const size_t nbgq2 = nb0*neq0*neq1;
+ const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+ const size_t nbgk1 = nb0*nek0;
+ const size_t nbgk2 = nb0*nek0*nek1;
+ const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+ const size_t nbgv1 = nb0*nev0;
+ const size_t nbgv2 = nb0*nev0*nev1;
+ const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+ // parallelize by k rows using ggml_vec_dot_f32
+
+ // total rows in k
+ const int nr = nek2*nek3;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ const float scale = 1.0f/sqrtf(D);
+
+ //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+ // how often k2 (and v2) is repeated in q2
+ int nrep = neq2/nek2;
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+        // k indices
+ const int ik3 = ir/(nek2);
+ const int ik2 = ir - ik3*nek2;
+
+ const int iq3 = ik3;
+ const int id3 = ik3;
+ const int iv3 = ik3;
+ const int iv2 = ik2;
+
+ for (int irep = 0; irep < nrep; ++irep) {
+ const int iq2 = ik2 + irep*nek2;
+ const int id2 = iq2;
+
+ // (ik2 + irep*nek2) % nek2 == ik2
+ for (int iq1 = 0; iq1 < neq1; ++iq1) {
+ const int id1 = iq1;
+
+                // not sure about CACHE_LINE_SIZE_F32 here:
+                // maybe it should not be multiplied by 2 and should not be included in the 1*(..) offset used for SM?
+ float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+ float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+ for (int i = M; i < Mup; ++i) {
+ S[i] = -INFINITY;
+ }
+
+ const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ // k indices
+ const int ik1 = ic;
+
+ // S indices
+ const int i1 = ik1;
+
+ ggml_vec_dot_f32(neq0,
+ S + i1, 0,
+ (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
+ }
+
+ // scale
+ ggml_vec_scale_f32(masked_begin, S, scale);
+
+ for (int64_t i = masked_begin; i < M; i++) {
+ S[i] = -INFINITY;
+ }
+
+ // softmax
+ // exclude known -INF S[..] values from max and loop
+                // don't forget to set their SM values to zero
+ {
+ float max = -INFINITY;
+ ggml_vec_max_f32(masked_begin, &max, S);
+
+ ggml_float sum = 0.0;
+ {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+ max = -max;
+ vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+ vvexpf(SM, SM, &Mup);
+ ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
+#endif
+ }
+
+ assert(sum > 0.0);
+
+ sum = 1.0/sum;
+ ggml_vec_scale_f32(masked_begin, SM, sum);
+
+ }
+
+ // step-by-step explanation
+ {
+ // forward-process shape grads from backward process
+ // parallel_for ik2,ik3:
+ // for irep:
+ // iq2 = ik2 + irep*nek2
+ // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
+ // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+ // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
+ // for iq1:
+ // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
+ // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
+ // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
+ // S0 = -Inf [D,1,1,1]
+ // ~S1[i] = dot(kcur[:D,i], qcur)
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
+ // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
+ // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
+ // ~S5[i] = dot(vcur[:,i], S4)
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
+ // ~dst[i,iq1,iq2,iq3] = S5[i] ^
+ // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+ // dst backward-/ grad[dst] = d
+ //
+ // output gradients with their dependencies:
+ //
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S4] = grad[S5] @ vcur
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[vcur] = grad[S5].T @ S4
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // in post-order:
+ //
+ // S1 = qcur @ kcur.T
+ // S2 = S1 * scale
+ // S3 = diag_mask_inf(S2, P)
+ // S4 = softmax(S3)
+ // grad[S4] = d[:D,id1,id2,id3] @ vcur
+ // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
+ // grad[S1] = diag_mask_zero(grad[S3], P) * scale
+ // grad[qcur] = grad[S1] @ kcur
+ // grad[kcur] = grad[S1].T @ qcur
+ // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+ //
+ // using less variables (SM=S4):
+ //
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
+ // SM = softmax(S)
+ // S = d[:D,iq1,iq2,iq3] @ vcur
+ // dot_SM_gradSM = dot(SM, S)
+ // S = SM * (S - dot(SM, S))
+ // S = diag_mask_zero(S, P) * scale
+ //
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ }
+
+ // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+ // for ic:
+ // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+ // exclude known future zero S[..] values from operation
+ ggml_vec_set_f32(masked_begin, S, 0);
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ S,
+ (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
+
+ // S = SM * (S - dot(SM, S))
+ float dot_SM_gradSM = 0;
+ ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
+ ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+ ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+ // S = diag_mask_zero(S, P) * scale
+ // already done by above ggml_vec_set_f32
+
+ // exclude known zero S[..] values from operation
+ ggml_vec_scale_f32(masked_begin, S, scale);
+
+ // S shape [M,1]
+ // SM shape [M,1]
+ // kcur shape [D,M]
+ // qcur shape [D,1]
+ // vcur shape [M,D]
+
+ // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+ // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+ // for ic:
+ // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
+ (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ S[ic]);
+ }
+
+                // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
+                // for ic:
+                //  grad[k][:D,ic,ik2,ik3] += S.T[0,ic] * qcur[:D,0]
+                //  grad[k][:D,ic,ik2,ik3] += S[ic] * qcur[:D,0]
+ // exclude known zero S[..] values from loop
+ for (int64_t ic = 0; ic < masked_begin; ++ic) {
+ ggml_vec_mad_f32(D,
+ (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
+ (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
+ S[ic]);
+ }
+
+ // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
+ // for ic:
+ // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+ // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
+ // exclude known zero SM[..] values from mad
+ for (int64_t ic = 0; ic < D; ++ic) {
+ ggml_vec_mad_f32(masked_begin,
+ (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+ SM,
+ *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+ const struct ggml_compute_params * params,
+ const bool masked,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * q = dst->src[0];
+
+ switch (q->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_state
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+ const struct ggml_tensor * src3 = dst->src[3]; // state_seq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src2->ne[0]; // d_conv
+ const int nr = src0->ne[1]; // d_inner
+ const int n_t = src1->ne[1]; // n_tokens
+ const int n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // for use with the destination state offset between sequences
+ GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+        // with multiple sequences it's hard to know when a state is read for the first time,
+        // so copy them all over to the destination, just to be sure.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
+            // can't use memcpy because the source rows hold d_conv - 1 elements while the destination rows hold d_conv
+ for (int i1 = 0; i1 < ir; ++i1) {
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ // copy s0 to last (d_conv - 1) columns of s
+ s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
+ }
+ }
+ }
+ }
+
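+    // causal depthwise 1d convolution: the state keeps the last (d_conv - 1) inputs per inner channel;
+    // for each token the state is shifted left, x is appended on the last column,
+    // and the output is the per-channel dot product of the state row with the conv filter c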
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
+ float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
+ float * s0; // {d_conv - 1, d_inner, n_kv}
+ float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
+ int ne0s0;
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
+ ne0s0 = src0->ne[0];
+ } else {
+ // the source is the last (d_conv - 1) columns of the destination
+ s0 = s + 1;
+ ne0s0 = nc;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // shift state left
+ for (int i0 = 0; i0 < nc - 1; ++i0) {
+ s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+ }
+ // insert x on the last column
+ s[(nc - 1) + i1*nc] = x0[i1];
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+
+ // it seems a little faster when this is separate from the state shift
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // rowwise dot product
+ float sumf = 0.0f;
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ sumf += s[i] * c[i];
+ }
+ x[i1] = sumf;
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_conv(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_conv_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0]; // s
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
+ const struct ggml_tensor * src3 = dst->src[3]; // A
+ const struct ggml_tensor * src4 = dst->src[4]; // B
+ const struct ggml_tensor * src5 = dst->src[5]; // C
+ const struct ggml_tensor * src6 = dst->src[6]; // sq
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = src1->ne[1]; // number of tokens in the batch
+ const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
+
+ GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
+ // required for the dot product between s and C, and when copying the states
+ GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+ // required for per-sequence offsets for states
+ GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+ // required to get correct offset for state destination (i.e. src1->nb[2])
+ GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+ const int ir = ir1 - ir0;
+
+ if (n_kv > 1) {
+        // with multiple sequences it's hard to know whether the source states have already been copied,
+        // so copy them all over to the destination up front.
+ for (int i3 = 0; i3 < n_kv; ++i3) {
+ float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
+ memcpy(s, s0, nc*ir*sizeof(float));
+ }
+ }
+
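+    // selective state-space recurrence (see the ref below): per inner channel and per token,
+    //   dt' = softplus(dt);  s <- s*expf(dt'*A) + B*(x*dt');  y = dot(s, C)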
+ for (int i2 = 0; i2 < n_t; ++i2) {
+ int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
+ float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
+ float * s0;
+ float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
+ float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
+ float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+ float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
+ float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
+
+ GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
+
+ // avoid needing to copy the state for the first token
+ if (i2 == 0) {
+ s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
+ } else {
+ // otherwise the source is the same as the destination
+ s0 = s;
+ }
+
+ // d_inner
+ for (int i1 = 0; i1 < ir; ++i1) {
+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+ float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+ float x_dt = x[i1] * dt_soft_plus;
+ float sumf = 0.0f;
+ // d_state
+ for (int i0 = 0; i0 < nc; ++i0) {
+ int i = i0 + i1*nc;
+ // state = prev_state * dA + dB * x
+ float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+ // y = rowwise_dotprod(state, C)
+ sumf += state * C[i0];
+ s[i] = state;
+ }
+ y[i1] = sumf;
+ }
+
+ // handle copies when there are multiple output states
+ for (int i3 = 1; i3 < n_kv; ++i3) {
+ int32_t seq = sq[i3];
+ if (0 <= seq && seq < n_kv) {
+ float * s1 = s + (seq - sq[0])*nc*nr;
+ memcpy(s1, s, nc*ir*sizeof(float));
+ } else {
+ // stop at negative or too big seq_ids
+ break;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_ssm_scan(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->src[0]->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_ssm_scan_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ UNUSED(params);
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+
+ const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t w = ((const int32_t *)(dst->op_params))[2];
+
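+    // split the (ne01 x ne02) spatial plane into nep0 x nep1 non-overlapping w x w windows,
+    // stacking the windows along dim 3 and zero-padding windows that extend past the input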
+ assert(ne00 == ne0);
+ assert(ne3 == nep0*nep1);
+
+ // TODO: optimize / multi-thread
+ for (int py = 0; py < nep1; ++py) {
+ for (int px = 0; px < nep0; ++px) {
+ const int64_t i3 = py*nep0 + px;
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ const int64_t i02 = py*w + i2;
+ const int64_t i01 = px*w + i1;
+ const int64_t i00 = i0;
+
+ const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
+ const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
+
+ if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+ ((float *) dst->data)[i] = 0.0f;
+ } else {
+ ((float *) dst->data)[i] = ((float *) src0->data)[j];
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_win_part(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_win_part_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ UNUSED(params);
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+
+ const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
+ // padding
+ const int px = (w - ne1%w)%w;
+ //const int py = (w - ne2%w)%w;
+
+ const int npx = (px + ne1)/w;
+ //const int npy = (py + ne2)/w;
+
+ assert(ne0 == ne00);
+
+ // TODO: optimize / multi-thread
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ const int ip2 = i2/w;
+ const int ip1 = i1/w;
+
+ const int64_t i02 = i2%w;
+ const int64_t i01 = i1%w;
+ const int64_t i00 = i0;
+
+ const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+ const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
+
+ ((float *) dst->data)[j] = ((float *) src0->data)[i];
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_win_unpart(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_win_unpart_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_unary
+
+static void ggml_compute_forward_unary(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const enum ggml_unary_op op = ggml_get_unary_op(dst);
+
+ switch (op) {
+ case GGML_UNARY_OP_ABS:
+ {
+ ggml_compute_forward_abs(params, dst);
+ } break;
+ case GGML_UNARY_OP_SGN:
+ {
+ ggml_compute_forward_sgn(params, dst);
+ } break;
+ case GGML_UNARY_OP_NEG:
+ {
+ ggml_compute_forward_neg(params, dst);
+ } break;
+ case GGML_UNARY_OP_STEP:
+ {
+ ggml_compute_forward_step(params, dst);
+ } break;
+ case GGML_UNARY_OP_TANH:
+ {
+ ggml_compute_forward_tanh(params, dst);
+ } break;
+ case GGML_UNARY_OP_ELU:
+ {
+ ggml_compute_forward_elu(params, dst);
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ ggml_compute_forward_relu(params, dst);
+ } break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ ggml_compute_forward_sigmoid(params, dst);
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ ggml_compute_forward_gelu(params, dst);
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ ggml_compute_forward_gelu_quick(params, dst);
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ ggml_compute_forward_silu(params, dst);
+ } break;
+ case GGML_UNARY_OP_HARDSWISH:
+ {
+ ggml_compute_forward_hardswish(params, dst);
+ } break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ {
+ ggml_compute_forward_hardsigmoid(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_get_rel_pos
+
+static void ggml_compute_forward_get_rel_pos_f16(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ UNUSED(params);
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int64_t w = ne1;
+
+ ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+ ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
+
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = 0; i1 < ne1; ++i1) {
+ const int64_t pos = (w - i1 - 1) + i2;
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_get_rel_pos(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
+ {
+ ggml_compute_forward_get_rel_pos_f16(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_add_rel_pos
+
+static void ggml_compute_forward_add_rel_pos_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+ if (!inplace) {
+ if (params->ith == 0) {
+ memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+ }
+ ggml_barrier(params->shared);
+ }
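+ // when not inplace, thread 0 first copies src0 into dst and all threads synchronize before accumulating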
+ // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+
+ float * src1_data = (float *) src1->data;
+ float * src2_data = (float *) src2->data;
+ float * dst_data = (float *) dst->data;
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // total patches in dst
+ const int np = ne13;
+
+ // patches per thread
+ const int dp = (np + nth - 1)/nth;
+
+ // patch range for this thread
+ const int ip0 = dp*ith;
+ const int ip1 = MIN(ip0 + dp, np);
+
+ for (int64_t i13 = ip0; i13 < ip1; ++i13) {
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t jp0 = jp1 + i10;
+ const float src1_e = src1_data[jp0];
+ const float src2_e = src2_data[jp0];
+
+ const int64_t jdh = jp0 * ne10;
+ const int64_t jdw = jdh - (ne10 - 1) * i10;
+
+ for (int64_t j = 0; j < ne10; ++j) {
+ dst_data[jdh + j ] += src2_e;
+ dst_data[jdw + j*ne10] += src1_e;
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_add_rel_pos(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_rel_pos_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_map_unary(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_unary_f32(params, dst, fun);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ assert(ggml_is_contiguous_1(src0));
+ assert(ggml_is_contiguous_1(src1));
+ assert(ggml_is_contiguous_1(dst));
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_map_binary(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_binary_f32(params, dst, fun);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_custom1_op_f32_t fun) {
+
+ const struct ggml_tensor * a = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ fun(dst, a);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_custom2_op_f32_t fun) {
+
+ const struct ggml_tensor * a = dst->src[0];
+ const struct ggml_tensor * b = dst->src[1];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ fun(dst, a, b);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst,
+ const ggml_custom3_op_f32_t fun) {
+
+ const struct ggml_tensor * a = dst->src[0];
+ const struct ggml_tensor * b = dst->src[1];
+ const struct ggml_tensor * c = dst->src[2];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ fun(dst, a, b, c);
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * a = dst->src[0];
+
+ struct ggml_map_custom1_op_params p;
+ memcpy(&p, dst->op_params, sizeof(p));
+
+ p.fun(dst, a, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * a = dst->src[0];
+ const struct ggml_tensor * b = dst->src[1];
+
+ struct ggml_map_custom2_op_params p;
+ memcpy(&p, dst->op_params, sizeof(p));
+
+ p.fun(dst, a, b, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * a = dst->src[0];
+ const struct ggml_tensor * b = dst->src[1];
+ const struct ggml_tensor * c = dst->src[2];
+
+ struct ggml_map_custom3_op_params p;
+ memcpy(&p, dst->op_params, sizeof(p));
+
+ p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_scalar(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
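+
+ // computes dst = -1/nr * sum_i sum_c src1[i,c] * log(softmax(src0[i])[c]),
+ // where the softmax is rescaled from [0,1] to [eps,1] to avoid log(0)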
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ float * sums = (float *) params->wdata;
+
+ // TODO: handle transposed/permuted matrices
+ const int nc = src0->ne[0];
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
+
+ if (ith == 0) {
+ memset(sums, 0, sizeof(float) * (nth + nth * nc));
+ }
+ ggml_barrier(params->shared);
+
+ const double eps = 1e-9;
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+ float * st = ((float *) params->wdata) + nth + ith*nc;
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(s0[i]));
+ assert(!isnan(s1[i]));
+ }
+#endif
+
+ // soft_max
+ float max = -INFINITY;
+ ggml_vec_max_f32(nc, &max, s0);
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
+ assert(sum > 0.0);
+ sum = (1.0 - eps) / sum;
+
+ // avoid log(0) by rescaling from [0..1] to [eps..1]
+ ggml_vec_scale_f32(nc, st, sum);
+ ggml_vec_add1_f32(nc, st, st, eps);
+ ggml_vec_log_f32(nc, st, st);
+ ggml_vec_mul_f32(nc, st, st, s1);
+
+ float st_sum = 0;
+ ggml_vec_sum_f32(nc, &st_sum, st);
+ sums[ith] += st_sum;
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(st[i]));
+ assert(!isinf(st[i]));
+ }
+#endif
+ }
+ ggml_barrier(params->shared);
+
+ if (ith == 0) {
+ float * dp = (float *) dst->data;
+ ggml_vec_sum_f32(nth, dp, sums);
+ dp[0] *= -1.0f / (float) nr;
+ }
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_cross_entropy_loss_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * opt0 = dst->src[2];
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(opt0));
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ const int64_t ith = params->ith;
+ const int64_t nth = params->nth;
+
+ const double eps = 1e-9;
+
+ // TODO: handle transposed/permuted matrices
+ const int64_t nc = src0->ne[0];
+ const int64_t nr = ggml_nrows(src0);
+
+ // rows per thread
+ const int64_t dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int64_t ir0 = dr*ith;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ float * d = (float *) opt0->data;
+
+ for (int64_t i1 = ir0; i1 < ir1; i1++) {
+ float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
+ float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
+ float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ //printf("p[%d] = %f\n", i, p[i]);
+ assert(!isnan(s0[i]));
+ assert(!isnan(s1[i]));
+ }
+#endif
+
+ // soft_max
+ float max = -INFINITY;
+ ggml_vec_max_f32(nc, &max, s0);
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+ assert(sum > 0.0);
+ sum = (1.0 - eps) / sum;
+
+ // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+ ggml_vec_scale_f32(nc, ds0, sum);
+ ggml_vec_add1_f32(nc, ds0, ds0, eps);
+ ggml_vec_sub_f32(nc, ds0, ds0, s1);
+ ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
+
+#ifndef NDEBUG
+ for (int i = 0; i < nc; ++i) {
+ assert(!isnan(ds0[i]));
+ assert(!isinf(ds0[i]));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_cross_entropy_loss_back(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+/////////////////////////////////
+
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+ GGML_ASSERT(params);
+
+ if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
+ return;
+ }
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ ggml_compute_forward_dup(params, tensor);
+ } break;
+ case GGML_OP_ADD:
+ {
+ ggml_compute_forward_add(params, tensor);
+ } break;
+ case GGML_OP_ADD1:
+ {
+ ggml_compute_forward_add1(params, tensor);
+ } break;
+ case GGML_OP_ACC:
+ {
+ ggml_compute_forward_acc(params, tensor);
+ } break;
+ case GGML_OP_SUB:
+ {
+ ggml_compute_forward_sub(params, tensor);
+ } break;
+ case GGML_OP_MUL:
+ {
+ ggml_compute_forward_mul(params, tensor);
+ } break;
+ case GGML_OP_DIV:
+ {
+ ggml_compute_forward_div(params, tensor);
+ } break;
+ case GGML_OP_SQR:
+ {
+ ggml_compute_forward_sqr(params, tensor);
+ } break;
+ case GGML_OP_SQRT:
+ {
+ ggml_compute_forward_sqrt(params, tensor);
+ } break;
+ case GGML_OP_LOG:
+ {
+ ggml_compute_forward_log(params, tensor);
+ } break;
+ case GGML_OP_SUM:
+ {
+ ggml_compute_forward_sum(params, tensor);
+ } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ ggml_compute_forward_sum_rows(params, tensor);
+ } break;
+ case GGML_OP_MEAN:
+ {
+ ggml_compute_forward_mean(params, tensor);
+ } break;
+ case GGML_OP_ARGMAX:
+ {
+ ggml_compute_forward_argmax(params, tensor);
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ ggml_compute_forward_repeat(params, tensor);
+ } break;
+ case GGML_OP_REPEAT_BACK:
+ {
+ ggml_compute_forward_repeat_back(params, tensor);
+ } break;
+ case GGML_OP_CONCAT:
+ {
+ ggml_compute_forward_concat(params, tensor);
+ } break;
+ case GGML_OP_SILU_BACK:
+ {
+ ggml_compute_forward_silu_back(params, tensor);
+ } break;
+ case GGML_OP_NORM:
+ {
+ ggml_compute_forward_norm(params, tensor);
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ ggml_compute_forward_rms_norm(params, tensor);
+ } break;
+ case GGML_OP_RMS_NORM_BACK:
+ {
+ ggml_compute_forward_rms_norm_back(params, tensor);
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ ggml_compute_forward_group_norm(params, tensor);
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ ggml_compute_forward_mul_mat(params, tensor);
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ ggml_compute_forward_mul_mat_id(params, tensor);
+ } break;
+ case GGML_OP_OUT_PROD:
+ {
+ ggml_compute_forward_out_prod(params, tensor);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ ggml_compute_forward_scale(params, tensor);
+ } break;
+ case GGML_OP_SET:
+ {
+ ggml_compute_forward_set(params, tensor);
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_compute_forward_cpy(params, tensor);
+ } break;
+ case GGML_OP_CONT:
+ {
+ ggml_compute_forward_cont(params, tensor);
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ ggml_compute_forward_reshape(params, tensor);
+ } break;
+ case GGML_OP_VIEW:
+ {
+ ggml_compute_forward_view(params, tensor);
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ ggml_compute_forward_permute(params, tensor);
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ ggml_compute_forward_transpose(params, tensor);
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ ggml_compute_forward_get_rows(params, tensor);
+ } break;
+ case GGML_OP_GET_ROWS_BACK:
+ {
+ ggml_compute_forward_get_rows_back(params, tensor);
+ } break;
+ case GGML_OP_DIAG:
+ {
+ ggml_compute_forward_diag(params, tensor);
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ ggml_compute_forward_diag_mask_inf(params, tensor);
+ } break;
+ case GGML_OP_DIAG_MASK_ZERO:
+ {
+ ggml_compute_forward_diag_mask_zero(params, tensor);
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ ggml_compute_forward_soft_max(params, tensor);
+ } break;
+ case GGML_OP_SOFT_MAX_BACK:
+ {
+ ggml_compute_forward_soft_max_back(params, tensor);
+ } break;
+ case GGML_OP_ROPE:
+ {
+ ggml_compute_forward_rope(params, tensor);
+ } break;
+ case GGML_OP_ROPE_BACK:
+ {
+ ggml_compute_forward_rope_back(params, tensor);
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ ggml_compute_forward_clamp(params, tensor);
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ ggml_compute_forward_conv_transpose_1d(params, tensor);
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ ggml_compute_forward_im2col(params, tensor);
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ ggml_compute_forward_conv_transpose_2d(params, tensor);
+ } break;
+ case GGML_OP_POOL_1D:
+ {
+ ggml_compute_forward_pool_1d(params, tensor);
+ } break;
+ case GGML_OP_POOL_2D:
+ {
+ ggml_compute_forward_pool_2d(params, tensor);
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ ggml_compute_forward_upscale(params, tensor);
+ } break;
+ case GGML_OP_PAD:
+ {
+ ggml_compute_forward_pad(params, tensor);
+ } break;
+ case GGML_OP_ARANGE:
+ {
+ ggml_compute_forward_arange(params, tensor);
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ ggml_compute_forward_timestep_embedding(params, tensor);
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ ggml_compute_forward_argsort(params, tensor);
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ ggml_compute_forward_leaky_relu(params, tensor);
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ int32_t t = ggml_get_op_params_i32(tensor, 0);
+ GGML_ASSERT(t == 0 || t == 1);
+ bool masked = t != 0;
+ ggml_compute_forward_flash_attn_back(params, masked, tensor);
+ } break;
+ case GGML_OP_SSM_CONV:
+ {
+ ggml_compute_forward_ssm_conv(params, tensor);
+ } break;
+ case GGML_OP_SSM_SCAN:
+ {
+ ggml_compute_forward_ssm_scan(params, tensor);
+ } break;
+ case GGML_OP_WIN_PART:
+ {
+ ggml_compute_forward_win_part(params, tensor);
+ } break;
+ case GGML_OP_WIN_UNPART:
+ {
+ ggml_compute_forward_win_unpart(params, tensor);
+ } break;
+ case GGML_OP_UNARY:
+ {
+ ggml_compute_forward_unary(params, tensor);
+ } break;
+ case GGML_OP_GET_REL_POS:
+ {
+ ggml_compute_forward_get_rel_pos(params, tensor);
+ } break;
+ case GGML_OP_ADD_REL_POS:
+ {
+ ggml_compute_forward_add_rel_pos(params, tensor);
+ } break;
+ case GGML_OP_MAP_UNARY:
+ {
+ ggml_unary_op_f32_t fun;
+ memcpy(&fun, tensor->op_params, sizeof(fun));
+ ggml_compute_forward_map_unary(params, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_BINARY:
+ {
+ ggml_binary_op_f32_t fun;
+ memcpy(&fun, tensor->op_params, sizeof(fun));
+ ggml_compute_forward_map_binary(params, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM1_F32:
+ {
+ ggml_custom1_op_f32_t fun;
+ memcpy(&fun, tensor->op_params, sizeof(fun));
+ ggml_compute_forward_map_custom1_f32(params, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM2_F32:
+ {
+ ggml_custom2_op_f32_t fun;
+ memcpy(&fun, tensor->op_params, sizeof(fun));
+ ggml_compute_forward_map_custom2_f32(params, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM3_F32:
+ {
+ ggml_custom3_op_f32_t fun;
+ memcpy(&fun, tensor->op_params, sizeof(fun));
+ ggml_compute_forward_map_custom3_f32(params, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM1:
+ {
+ ggml_compute_forward_map_custom1(params, tensor);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM2:
+ {
+ ggml_compute_forward_map_custom2(params, tensor);
+ }
+ break;
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ ggml_compute_forward_map_custom3(params, tensor);
+ }
+ break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ ggml_compute_forward_cross_entropy_loss(params, tensor);
+ }
+ break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor);
+ }
+ break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static size_t ggml_hash_size(size_t min_sz) {
+ // next primes after powers of two
+ static const size_t primes[] = {
+ 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
+ 2053, 4099, 8209, 16411, 32771, 65537, 131101,
+ 262147, 524309, 1048583, 2097169, 4194319, 8388617,
+ 16777259, 33554467, 67108879, 134217757, 268435459,
+ 536870923, 1073741827, 2147483659
+ };
+ static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
+
+ // find the smallest prime that is larger than or equal to min_sz
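+ // e.g. a requested min_sz of 1000 maps to 1031; requests beyond the largest table entry fall back to min_sz | 1 (odd, but not necessarily prime)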
+ size_t l = 0;
+ size_t r = n_primes;
+ while (l < r) {
+ size_t m = (l + r)/2;
+ if (primes[m] < min_sz) {
+ l = m + 1;
+ } else {
+ r = m;
+ }
+ }
+ size_t sz = l < n_primes ? primes[l] : min_sz | 1;
+ return sz;
+}
+
+static size_t ggml_hash(const void * p) {
+ return (size_t)p;
+}
+
+size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t h = ggml_hash(key) % hash_set.size;
+
+ // linear probing
+ size_t i = h;
+ while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) {
+ i = (i + 1) % hash_set.size;
+ if (i == h) {
+ // visited all hash table entries -> not found
+ return GGML_HASHTABLE_FULL;
+ }
+ }
+ return i;
+}
+
+bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
+ return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key;
+}
+
+size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
+
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+ if (hash_set.keys[i] == key) {
+ return GGML_HASHTABLE_ALREADY_EXISTS;
+ }
+
+ // insert
+ GGML_ASSERT(hash_set.keys[i] == NULL);
+ hash_set.keys[i] = key;
+ return i;
+}
+
+size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
+
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+ hash_set.keys[i] = key;
+ return i;
+}
+
+struct ggml_hash_set ggml_hash_set_new(size_t size) {
+ size = ggml_hash_size(size);
+ struct ggml_hash_set result;
+ result.size = size;
+ result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);
+ memset(result.keys, 0, sizeof(struct ggml_tensor *) * size);
+ return result;
+}
+
+static void ggml_hash_set_free(struct ggml_hash_set hash_set) {
+ GGML_FREE(hash_set.keys);
+}
+
+struct hash_map {
+ struct ggml_hash_set set;
+ struct ggml_tensor ** vals;
+};
+
+static struct hash_map * ggml_new_hash_map(size_t size) {
+ struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));
+ result->set = ggml_hash_set_new(size);
+ result->vals = GGML_MALLOC(sizeof(struct ggml_tensor *) * result->set.size);
+ memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size);
+ return result;
+}
+
+static void ggml_hash_map_free(struct hash_map * map) {
+ ggml_hash_set_free(map->set);
+ GGML_FREE(map->vals);
+ GGML_FREE(map);
+}
+
+// gradient checkpointing
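+// ggml_recompute_graph_node returns the node itself for NULL inputs, params, nodes not in the graph's visited hash table,
+// and source-less leafs; otherwise it returns a (memoized) clone whose sources are rebuilt recursively, so that
+// non-checkpointed intermediate values are recomputed from the checkpoints in the backward graph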
+
+static struct ggml_tensor * ggml_recompute_graph_node(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ struct hash_map * replacements,
+ struct ggml_tensor * node) {
+
+ if (node == NULL) {
+ return NULL;
+ }
+
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+ return node;
+ }
+
+ if (!ggml_hash_contains(graph->visited_hash_table, node)) {
+ return node;
+ }
+
+ int count_children = 0;
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ if (node->src[k]) {
+ ++count_children;
+ }
+ }
+
+ if (count_children == 0) {
+ return node;
+ }
+
+ size_t i = ggml_hash_find(replacements->set, node);
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full
+ if (replacements->set.keys[i] == node) {
+ return replacements->vals[i];
+ }
+
+ struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne);
+
+ // insert clone into replacements
+ GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
+ replacements->set.keys[i] = node;
+ replacements->vals[i] = clone;
+
+ clone->op = node->op;
+ clone->grad = node->grad;
+ clone->flags = node->flags;
+ clone->extra = node->extra;
+ for (int k = 0; k < GGML_MAX_DIMS; ++k) {
+ clone->nb[k] = node->nb[k];
+ }
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
+ }
+ if (node->view_src != NULL) {
+ clone->data = (node->view_src->data == NULL)
+ ? NULL // view_src not yet allocated
+ : (char *) node->view_src->data // view_src already allocated
+ + node->view_offs;
+ clone->view_src = node->view_src;
+ clone->view_offs = node->view_offs;
+ }
+
+ GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
+ GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
+ memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
+ ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
+
+ return clone;
+}
+
+void ggml_build_backward_gradient_checkpointing(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ struct ggml_cgraph * gb_tmp,
+ struct ggml_tensor * * checkpoints,
+ int n_checkpoints) {
+ ggml_graph_cpy(gf, gb_tmp);
+ ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+
+ if (n_checkpoints <= 0) {
+ ggml_graph_cpy(gb_tmp, gb);
+ return;
+ }
+
+ struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
+
+ // insert checkpoints in replacements
+ for (int i = 0; i < n_checkpoints; ++i) {
+ size_t k = ggml_hash_find(replacements->set, checkpoints[i]);
+ GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full
+ GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
+ replacements->set.keys[k] = checkpoints[i];
+ replacements->vals[k] = checkpoints[i];
+ }
+
+ ggml_graph_cpy(gf, gb);
+ // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
+ // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
+ // by recomputing them from checkpoints
+ for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
+ struct ggml_tensor * node = gb_tmp->nodes[i];
+ for (int k = 0; k < GGML_MAX_SRC; ++k) {
+ // insert new tensors that recompute src, reusing replacements that have already been made,
+ // and record each new tensor in replacements, keyed by the corresponding gf node;
+ // recursion over the input tensors terminates when an input is itself a replacement (e.g. a checkpoint)
+ node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
+ }
+ // insert rewritten backward node with replacements made into resulting backward graph gb
+ ggml_build_forward_expand(gb, node);
+ }
+
+ ggml_hash_map_free(replacements);
+}
+
+// functions to change gradients, taking into account that input a might be the initial gradient with zero value
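+// if a is still listed in zero_table it is an untouched initial gradient (conceptually zero), so instead of
+// accumulating into a, the result is built from b alone (b, -b, repeat(b), or b acc'd into a zeroed copy of a);
+// otherwise a regular add/add1/acc/sub is emitted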
+
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
+ return b;
+ } else {
+ return ggml_add_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
+ struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
+ return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+ } else {
+ return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+ }
+}
+
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
+ return ggml_repeat(ctx, b, a);
+ } else {
+ return ggml_add1_impl(ctx, a, b, false);
+ }
+}
+
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
+ return ggml_neg(ctx, b);
+ } else {
+ return ggml_sub_impl(ctx, a, b, false);
+ }
+}
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
+ struct ggml_tensor * src0 = tensor->src[0];
+ struct ggml_tensor * src1 = tensor->src[1];
+ struct ggml_tensor * src2 = tensor->src[2];
+
+ switch (tensor->op) {
+ case GGML_OP_DUP:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ } break;
+ case GGML_OP_ADD:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
+ }
+ } break;
+ case GGML_OP_ADD1:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_add_or_set(ctx,
+ src1->grad,
+ ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
+ zero_table);
+ }
+ } break;
+ case GGML_OP_ACC:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ if (src1->grad) {
+ const size_t nb1 = ((int32_t *) tensor->op_params)[0];
+ const size_t nb2 = ((int32_t *) tensor->op_params)[1];
+ const size_t nb3 = ((int32_t *) tensor->op_params)[2];
+ const size_t offset = ((int32_t *) tensor->op_params)[3];
+
+ struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
+ tensor->grad,
+ src1->grad->ne[0],
+ src1->grad->ne[1],
+ src1->grad->ne[2],
+ src1->grad->ne[3],
+ nb1, nb2, nb3, offset);
+
+ src1->grad =
+ ggml_add_or_set(ctx,
+ src1->grad,
+ ggml_reshape(ctx,
+ ggml_cont(ctx, tensor_grad_view),
+ src1->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SUB:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ if (src1->grad) {
+ src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
+ }
+ } break;
+ case GGML_OP_MUL:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx, src1, tensor->grad),
+ zero_table);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_or_set(ctx,
+ src1->grad,
+ ggml_mul(ctx, src0, tensor->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_DIV:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_div(ctx, tensor->grad, src1),
+ zero_table);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_sub_or_set(ctx,
+ src1->grad,
+ ggml_mul(ctx,
+ tensor->grad,
+ ggml_div(ctx, tensor, src1)),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SQR:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_scale(ctx,
+ ggml_mul(ctx, src0, tensor->grad),
+ 2.0f),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SQRT:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_scale(ctx,
+ ggml_div(ctx,
+ tensor->grad,
+ tensor),
+ 0.5f),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_LOG:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_div(ctx,
+ tensor->grad,
+ src0),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SUM:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add1_or_set(ctx,
+ src0->grad,
+ tensor->grad,
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SUM_ROWS:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_repeat(ctx,
+ tensor->grad,
+ src0->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_MEAN:
+ case GGML_OP_ARGMAX:
+ {
+ GGML_ASSERT(false); // TODO: implement
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_repeat_back(ctx, tensor->grad, src0->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_REPEAT_BACK:
+ {
+ if (src0->grad) {
+ // TODO: test this
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_repeat(ctx, tensor->grad, src0->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_CONCAT:
+ {
+ GGML_ASSERT(false); // TODO: implement
+ } break;
+ case GGML_OP_SILU_BACK:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_NORM:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_RMS_NORM:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ float eps;
+ memcpy(&eps, tensor->op_params, sizeof(float));
+
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_RMS_NORM_BACK:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ // https://cs231n.github.io/optimization-2/#staged
+ // # forward pass
+ // s0 = np.random.randn(5, 10)
+ // s1 = np.random.randn(10, 3)
+ // t = s0.dot(s1)
+
+ // # now suppose we had the gradient on t from above in the circuit
+ // dt = np.random.randn(*t.shape) # same shape as t
+ // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
+ // ds1 = t.T.dot(dt)
+
+ // tensor.shape [m,p,qq,rr]
+ // src0.shape [n,m,q1,r1]
+ // src1.shape [n,p,qq,rr]
+
+ // necessary for llama
+ if (src0->grad) {
+ struct ggml_tensor * s1_tg =
+ ggml_out_prod(ctx, // [n,m,qq,rr]
+ src1, // [n,p,qq,rr]
+ tensor->grad); // [m,p,qq,rr]
+ const int64_t qq = s1_tg->ne[2];
+ const int64_t rr = s1_tg->ne[3];
+ const int64_t q1 = src0->ne[2];
+ const int64_t r1 = src0->ne[3];
+ const bool ne2_broadcasted = qq > q1;
+ const bool ne3_broadcasted = rr > r1;
+ if (ne2_broadcasted || ne3_broadcasted) {
+ // sum broadcast repetitions of s1_tg into shape of src0
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+ }
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad, // [n,m,q1,r1]
+ s1_tg, // [n,m,q1,r1]
+ zero_table);
+ }
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_or_set(ctx,
+ src1->grad, // [n,p,qq,rr]
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
+ // ggml_cont(ctx, // [m,n,q1,r1]
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+ // tensor->grad), // [m,p,qq,rr]
+
+ // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+ // // avoid transpose of src0, rather transpose smaller tensor->grad
+ // // and then use ggml_out_prod
+ ggml_out_prod(ctx, // [n,p,qq,rr]
+ src0, // [n,m,q1,r1]
+ ggml_transpose(ctx, // [p,m,qq,rr]
+ tensor->grad)), // [m,p,qq,rr]
+ zero_table);
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_OUT_PROD:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_SCALE:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ float s;
+ memcpy(&s, tensor->op_params, sizeof(float));
+
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_scale_impl(ctx, tensor->grad, s, false),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SET:
+ {
+ const size_t nb1 = ((int32_t *) tensor->op_params)[0];
+ const size_t nb2 = ((int32_t *) tensor->op_params)[1];
+ const size_t nb3 = ((int32_t *) tensor->op_params)[2];
+ const size_t offset = ((int32_t *) tensor->op_params)[3];
+
+ struct ggml_tensor * tensor_grad_view = NULL;
+
+ if (src0->grad || src1->grad) {
+ GGML_ASSERT(src0->type == tensor->type);
+ GGML_ASSERT(tensor->grad->type == tensor->type);
+ GGML_ASSERT(tensor->grad->type == src1->grad->type);
+
+ tensor_grad_view = ggml_view_4d(ctx,
+ tensor->grad,
+ src1->grad->ne[0],
+ src1->grad->ne[1],
+ src1->grad->ne[2],
+ src1->grad->ne[3],
+ nb1, nb2, nb3, offset);
+ }
+
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_acc_impl(ctx,
+ tensor->grad,
+ ggml_neg(ctx, tensor_grad_view),
+ nb1, nb2, nb3, offset, false),
+ zero_table);
+ }
+
+ if (src1->grad) {
+ src1->grad =
+ ggml_add_or_set(ctx,
+ src1->grad,
+ ggml_reshape(ctx,
+ ggml_cont(ctx, tensor_grad_view),
+ src1->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ // necessary for llama
+ // cpy overwrites value of src1 by src0 and returns view(src1)
+ // the overwriting is mathematically equivalent to:
+ // tensor = src0 * 1 + src1 * 0
+ if (src0->grad) {
+ // dsrc0 = dtensor * 1
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ if (src1->grad) {
+ // dsrc1 = dtensor * 0 -> noop
+ }
+ } break;
+ case GGML_OP_CONT:
+ {
+ // same as cpy
+ if (src0->grad) {
+ GGML_ASSERT(ggml_is_contiguous(src0->grad));
+ GGML_ASSERT(ggml_is_contiguous(tensor->grad));
+ src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ } break;
+ case GGML_OP_RESHAPE:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_reshape(ctx,
+ ggml_is_contiguous(tensor->grad)
+ ? tensor->grad
+ : ggml_cont(ctx, tensor->grad),
+ src0->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_VIEW:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ size_t offset;
+
+ memcpy(&offset, tensor->op_params, sizeof(offset));
+
+ size_t nb1 = tensor->nb[1];
+ size_t nb2 = tensor->nb[2];
+ size_t nb3 = tensor->nb[3];
+
+ if (src0->type != src0->grad->type) {
+ // gradient is typically F32, but src0 could be other type
+ size_t ng = ggml_element_size(src0->grad);
+ size_t n0 = ggml_element_size(src0);
+ GGML_ASSERT(offset % n0 == 0);
+ GGML_ASSERT(nb1 % n0 == 0);
+ GGML_ASSERT(nb2 % n0 == 0);
+ GGML_ASSERT(nb3 % n0 == 0);
+ offset = (offset / n0) * ng;
+ nb1 = (nb1 / n0) * ng;
+ nb2 = (nb2 / n0) * ng;
+ nb3 = (nb3 / n0) * ng;
+ }
+
+ src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
+ }
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ int32_t * axes = (int32_t *) tensor->op_params;
+ int axis0 = axes[0] & 0x3;
+ int axis1 = axes[1] & 0x3;
+ int axis2 = axes[2] & 0x3;
+ int axis3 = axes[3] & 0x3;
+ int axes_backward[4] = {0,0,0,0};
+ axes_backward[axis0] = 0;
+ axes_backward[axis1] = 1;
+ axes_backward[axis2] = 2;
+ axes_backward[axis3] = 3;
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_permute(ctx,
+ tensor->grad,
+ axes_backward[0],
+ axes_backward[1],
+ axes_backward[2],
+ axes_backward[3]),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_transpose(ctx, tensor->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ // necessary for llama (only for tokenizer)
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ // last ggml_get_rows_back argument src0->grad is only
+ // necessary to setup correct output shape
+ ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
+ zero_table);
+ }
+ if (src1->grad) {
+ // noop
+ }
+ } break;
+ case GGML_OP_GET_ROWS_BACK:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_DIAG:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ const int n_past = ((int32_t *) tensor->op_params)[0];
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ /* ggml_diag_mask_inf_impl() shouldn't be here */
+ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+ ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_DIAG_MASK_ZERO:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ const int n_past = ((int32_t *) tensor->op_params)[0];
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_SOFT_MAX:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx, src0->grad,
+ ggml_soft_max_back(ctx, tensor->grad, tensor),
+ zero_table);
+ }
+
+ } break;
+ case GGML_OP_SOFT_MAX_BACK:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_ROPE:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
+ const int mode = ((int32_t *) tensor->op_params)[2];
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
+
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_rope_back(ctx,
+ tensor->grad,
+ src1,
+ src2,
+ n_dims,
+ mode,
+ n_ctx_orig,
+ freq_base,
+ freq_scale,
+ ext_factor,
+ attn_factor,
+ beta_fast,
+ beta_slow),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_ROPE_BACK:
+ {
+ if (src0->grad) {
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
+ const int mode = ((int32_t *) tensor->op_params)[2];
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
+
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_rope_impl(ctx,
+ tensor->grad,
+ src1,
+ src2,
+ n_dims,
+ mode,
+ n_ctx_orig,
+ freq_base,
+ freq_scale,
+ ext_factor,
+ attn_factor,
+ beta_fast,
+ beta_slow,
+ false),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_IM2COL:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_POOL_1D:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_POOL_2D:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_ARGSORT:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ struct ggml_tensor * flash_grad = NULL;
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
+ int32_t t = ggml_get_op_params_i32(tensor, 0);
+ GGML_ASSERT(t == 0 || t == 1);
+ bool masked = t != 0;
+ flash_grad =
+ ggml_flash_attn_back(ctx,
+ src0,
+ src1,
+ tensor->src[2],
+ tensor->grad,
+ masked);
+ }
+
+ const int64_t elem_q = ggml_nelements(src0);
+ const int64_t elem_k = ggml_nelements(src1);
+ const int64_t elem_v = ggml_nelements(src2);
+
+ enum ggml_type result_type = flash_grad->type;
+ GGML_ASSERT(ggml_blck_size(result_type) == 1);
+ const size_t tsize = ggml_type_size(result_type);
+
+ const size_t offs_q = 0;
+ const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+ const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+ if (src0->grad) {
+ struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
+ struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ grad_q,
+ zero_table);
+ }
+ if (src1->grad) {
+ struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
+ struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
+ src1->grad = ggml_add_or_set(ctx,
+ src1->grad,
+ grad_k,
+ zero_table);
+ }
+ if (src2->grad) {
+ struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
+ struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
+ src2->grad = ggml_add_or_set(ctx,
+ src2->grad,
+ grad_v,
+ zero_table);
+ }
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_WIN_PART:
+ case GGML_OP_WIN_UNPART:
+ case GGML_OP_UNARY:
+ {
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_ABS:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_sgn(ctx, src0),
+ tensor->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_UNARY_OP_SGN:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_UNARY_OP_NEG:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
+ }
+ } break;
+ case GGML_UNARY_OP_STEP:
+ {
+ if (src0->grad) {
+ // noop
+ }
+ } break;
+ case GGML_UNARY_OP_TANH:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_UNARY_OP_ELU:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_UNARY_OP_RELU:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ ggml_step(ctx, src0),
+ tensor->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_UNARY_OP_SIGMOID:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_UNARY_OP_GELU:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ // necessary for llama
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_silu_back(ctx, src0, tensor->grad),
+ zero_table);
+ }
+ } break;
+ default:
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_OP_GET_REL_POS:
+ case GGML_OP_ADD_REL_POS:
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ case GGML_OP_MAP_CUSTOM1_F32:
+ case GGML_OP_MAP_CUSTOM2_F32:
+ case GGML_OP_MAP_CUSTOM3_F32:
+ case GGML_OP_MAP_CUSTOM1:
+ case GGML_OP_MAP_CUSTOM2:
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ if (src0->grad) {
+ src0->grad = ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_cross_entropy_loss_back(ctx,
+ src0,
+ src1,
+ tensor->grad),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
+ case GGML_OP_NONE:
+ {
+ // nop
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (tensor->src[i] && tensor->src[i]->grad) {
+ GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
+ }
+ }
+}
+
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+ if (node->grad == NULL) {
+ // this usually happens when we generate intermediate nodes from constants in the backward pass
+ // it can also happen during the forward pass, if the user performs computations with constants
+ if (node->op != GGML_OP_NONE) {
+ //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
+ }
+ }
+
+ // check if already visited
+ if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) {
+ return;
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ const int k =
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+ (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+ /* unknown order, just fall back to using i */ i;
+ if (node->src[k]) {
+ ggml_visit_parents(cgraph, node->src[k]);
+ }
+ }
+
+ if (node->op == GGML_OP_NONE && node->grad == NULL) {
+ // reached a leaf node, not part of the gradient graph (e.g. a constant)
+ GGML_ASSERT(cgraph->n_leafs < cgraph->size);
+
+ if (strlen(node->name) == 0) {
+ ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
+ }
+
+ cgraph->leafs[cgraph->n_leafs] = node;
+ cgraph->n_leafs++;
+ } else {
+ GGML_ASSERT(cgraph->n_nodes < cgraph->size);
+
+ if (strlen(node->name) == 0) {
+ ggml_format_name(node, "node_%d", cgraph->n_nodes);
+ }
+
+ cgraph->nodes[cgraph->n_nodes] = node;
+ if (cgraph->grads) {
+ cgraph->grads[cgraph->n_nodes] = node->grad;
+ }
+ cgraph->n_nodes++;
+ }
+}
+
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+ if (!expand) {
+ // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
+ ggml_graph_clear(cgraph);
+ }
+
+ const int n0 = cgraph->n_nodes;
+ UNUSED(n0);
+
+ ggml_visit_parents(cgraph, tensor);
+
+ const int n_new = cgraph->n_nodes - n0;
+ GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+ if (n_new > 0) {
+ // the last added node should always be the starting point
+ GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+ }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+ ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
+ GGML_ASSERT(gf->n_nodes > 0);
+
+ // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
+ if (keep) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->grad) {
+ node->grad = ggml_dup_tensor(ctx, node);
+ gf->grads[i] = node->grad;
+ }
+ }
+ }
+
+ // remember original gradients which start with zero values
+ struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+ for (int i = 0; i < gf->n_nodes; i++) {
+ if (gf->grads[i]) {
+ ggml_hash_insert(zero_table, gf->grads[i]);
+ }
+ }
+
+ for (int i = gf->n_nodes - 1; i >= 0; i--) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ // inplace operations to add gradients are not created by ggml_compute_backward;
+ // instead, the allocator is relied on to automatically make them inplace
+ if (node->grad) {
+ ggml_compute_backward(ctx, node, zero_table);
+ }
+ }
+
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+ GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+ ggml_build_forward_expand(gb, node->grad);
+ }
+ }
+
+ ggml_hash_set_free(zero_table);
+}
+
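+// size of the buffer needed for a graph object: the ggml_cgraph header followed by the nodes, leafs,
+// hash-key and (optional) grads pointer arrays (see ggml_new_graph_custom for the exact layout)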
+static size_t ggml_graph_nbytes(size_t size, bool grads) {
+ size_t nbytes = sizeof(struct ggml_cgraph);
+ nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes
+ if (grads) {
+ nbytes += size * sizeof(struct ggml_tensor *); // grads
+ }
+ nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set
+ return nbytes;
+}
+
+size_t ggml_graph_overhead_custom(size_t size, bool grads) {
+ return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
+}
+
+size_t ggml_graph_overhead(void) {
+ return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
+ const size_t obj_size = ggml_graph_nbytes(size, grads);
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);
+ struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+ struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1);
+
+ size_t hash_size = ggml_hash_size(size * 2);
+ struct ggml_tensor ** nodes_ptr = data_start;
+ struct ggml_tensor ** leafs_ptr = nodes_ptr + size;
+ struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size;
+ struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL;
+
+ // check that we allocated the correct amount of memory
+ assert(obj_size == (size_t) (
+ (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph));
+
+ memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *));
+
+ *cgraph = (struct ggml_cgraph) {
+ /*.size =*/ size,
+ /*.n_nodes =*/ 0,
+ /*.n_leafs =*/ 0,
+ /*.nodes =*/ nodes_ptr,
+ /*.grads =*/ grads_ptr,
+ /*.leafs =*/ leafs_ptr,
+ /*.hash_table =*/ { hash_size, hash_keys_ptr },
+ /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
+ };
+
+ return cgraph;
+}
+
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+ return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
+ struct ggml_cgraph cgraph = {
+ /*.size =*/ 0,
+ /*.n_nodes =*/ i1 - i0,
+ /*.n_leafs =*/ 0,
+ /*.nodes =*/ cgraph0->nodes + i0,
+ /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
+ /*.leafs =*/ NULL,
+ /*.hash_table =*/ { 0, NULL },
+ /*.order =*/ cgraph0->order,
+ };
+
+ return cgraph;
+}
+
+void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
+ GGML_ASSERT(dst->size >= src->n_leafs);
+ GGML_ASSERT(dst->size >= src->n_nodes);
+ GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size);
+
+ dst->n_leafs = src->n_leafs;
+ dst->n_nodes = src->n_nodes;
+ dst->order = src->order;
+
+ for (int i = 0; i < src->n_leafs; ++i) {
+ dst->leafs[i] = src->leafs[i];
+ }
+
+ for (int i = 0; i < src->n_nodes; ++i) {
+ dst->nodes[i] = src->nodes[i];
+ }
+
+ if (src->grads) {
+ GGML_ASSERT(dst->grads != NULL);
+ for (int i = 0; i < src->n_nodes; ++i) {
+ dst->grads[i] = src->grads[i];
+ }
+ }
+
+ for (size_t i = 0; i < src->visited_hash_table.size; ++i) {
+ if (src->visited_hash_table.keys[i]) {
+ ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]);
+ }
+ }
+}
+
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+ struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+ ggml_graph_cpy(cgraph, result);
+ return result;
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(cgraph->grads != NULL);
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * grad = cgraph->grads[i];
+
+ if (grad) {
+ ggml_set_zero(grad);
+ }
+ }
+}
+
+void ggml_graph_clear(struct ggml_cgraph * cgraph) {
+ cgraph->n_leafs = 0;
+ cgraph->n_nodes = 0;
+ memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
+}
+
+//
+// thread data
+//
+// synchronization is done via busy loops
+// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops
+//
+
+#ifdef __APPLE__
+
+//#include <os/lock.h>
+//
+//typedef os_unfair_lock ggml_lock_t;
+//
+//#define ggml_lock_init(x) UNUSED(x)
+//#define ggml_lock_destroy(x) UNUSED(x)
+//#define ggml_lock_lock os_unfair_lock_lock
+//#define ggml_lock_unlock os_unfair_lock_unlock
+//
+//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#define ggml_lock_lock(x) UNUSED(x)
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
+#else
+
+//typedef pthread_spinlock_t ggml_lock_t;
+
+//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE)
+//#define ggml_lock_destroy pthread_spin_destroy
+//#define ggml_lock_lock pthread_spin_lock
+//#define ggml_lock_unlock pthread_spin_unlock
+
+typedef int ggml_lock_t;
+
+#define ggml_lock_init(x) UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
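+// on x86, "locking" inside the busy loop is just a pause hint (_mm_pause), meant to reduce pressure on the sibling hyper-thread while spinning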
+#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
+#define ggml_lock_lock(x) _mm_pause()
+#else
+#define ggml_lock_lock(x) UNUSED(x)
+#endif
+#define ggml_lock_unlock(x) UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join pthread_join
+
+#endif
+
+// Android's libc implementation "bionic" does not support setting affinity
+#if defined(__gnu_linux__)
+static void set_numa_thread_affinity(int thread_n) {
+ if (!ggml_is_numa()) {
+ return;
+ }
+
+ int node_num;
+ int rv;
+ size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+ switch(g_state.numa.numa_strategy) {
+ case GGML_NUMA_STRATEGY_DISTRIBUTE:
+ // distribute threads across NUMA nodes: run this thread on node (thread_n % number of nodes)
+ node_num = thread_n % g_state.numa.n_nodes;
+ break;
+ case GGML_NUMA_STRATEGY_ISOLATE:
+ // run thread on current_node
+ node_num = g_state.numa.current_node;
+ break;
+ case GGML_NUMA_STRATEGY_NUMACTL:
+ // use the cpuset that numactl gave us
+ rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
+ if (rv) {
+ fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
+ }
+ return;
+ default:
+ return;
+ }
+
+ struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
+
+ cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+ CPU_ZERO_S(setsize, cpus);
+ for (size_t i = 0; i < node->n_cpus; ++i) {
+ CPU_SET_S(node->cpus[i], setsize, cpus);
+ }
+
+ rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+ if (rv) {
+ fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+ }
+
+ CPU_FREE(cpus);
+}
+
+static void clear_numa_thread_affinity(void) {
+ if (!ggml_is_numa()) {
+ return;
+ }
+
+ size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+ cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+ CPU_ZERO_S(setsize, cpus);
+ for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
+ CPU_SET_S(i, setsize, cpus);
+ }
+
+ int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+ if (rv) {
+ fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+ }
+
+ CPU_FREE(cpus);
+}
+#else
+// TODO: Windows etc.
+// (the linux implementation may also work on BSD, someone should test)
+static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
+static void clear_numa_thread_affinity(void) {}
+#endif
+
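+// number of threads that should cooperate on the given node; cheap or memory-bound ops
+// (views, reshapes, argmax, pooling, ...) are deliberately left single-threaded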
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+ int n_tasks = 0;
+
+ if (ggml_is_empty(node)) {
+ // no need to multi-thread a no-op
+ n_tasks = 1;
+ return n_tasks;
+ }
+
+ switch (node->op) {
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ case GGML_OP_ACC:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_SUB:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_LOG:
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
+ case GGML_OP_LEAKY_RELU:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_STEP:
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ {
+ n_tasks = 1;
+ } break;
+
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ {
+ n_tasks = n_threads;
+ } break;
+ default:
+ GGML_ASSERT(false);
+ }
+ break;
+ case GGML_OP_SILU_BACK:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_CONCAT:
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_OUT_PROD:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_GET_ROWS:
+ {
+ // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+ // decreases performance with GPU offloading
+ //n_tasks = n_threads;
+ n_tasks = 1;
+ } break;
+ case GGML_OP_SET:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS_BACK:
+ case GGML_OP_DIAG:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_DIAG_MASK_ZERO:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX_BACK:
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ case GGML_OP_ADD_REL_POS:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_SCALE:
+ case GGML_OP_SOFT_MAX:
+ {
+ n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+ } break;
+ case GGML_OP_IM2COL:
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_ARGSORT:
+ case GGML_OP_FLASH_ATTN_EXT:
+ case GGML_OP_FLASH_ATTN_BACK:
+ case GGML_OP_SSM_CONV:
+ case GGML_OP_SSM_SCAN:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_WIN_PART:
+ case GGML_OP_WIN_UNPART:
+ case GGML_OP_GET_REL_POS:
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ case GGML_OP_MAP_CUSTOM1_F32:
+ case GGML_OP_MAP_CUSTOM2_F32:
+ case GGML_OP_MAP_CUSTOM3_F32:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_MAP_CUSTOM1:
+ {
+ struct ggml_map_custom1_op_params p;
+ memcpy(&p, node->op_params, sizeof(p));
+ if (p.n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p.n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_MAP_CUSTOM2:
+ {
+ struct ggml_map_custom2_op_params p;
+ memcpy(&p, node->op_params, sizeof(p));
+ if (p.n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p.n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ struct ggml_map_custom3_op_params p;
+ memcpy(&p, node->op_params, sizeof(p));
+ if (p.n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p.n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_NONE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ default:
+ {
+ fprintf(stderr, "%s: op not implemented: ", __func__);
+ if (node->op < GGML_OP_COUNT) {
+ fprintf(stderr, "%s\n", ggml_op_name(node->op));
+ } else {
+ fprintf(stderr, "%d\n", node->op);
+ }
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ assert(n_tasks > 0);
+
+ return n_tasks;
+}
+
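+// estimates a per-op thread count and the size of the scratch ("work") buffer needed to
+// evaluate the graph; if the returned cplan.work_size is non-zero, the caller must point
+// cplan.work_data at a buffer of at least that many bytes before ggml_graph_compute()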
+struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
+ if (n_threads <= 0) {
+ n_threads = GGML_DEFAULT_N_THREADS;
+ }
+
+ size_t work_size = 0;
+
+ struct ggml_cplan cplan;
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
+
+ int max_tasks = 1;
+
+ // thread scheduling for the different operations + work buffer size estimation
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
+ max_tasks = MAX(max_tasks, n_tasks);
+
+ size_t cur = 0;
+
+ switch (node->op) {
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ {
+ if (ggml_is_quantized(node->type) ||
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+ }
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ {
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+ } break;
+ case GGML_OP_ACC:
+ {
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
+ }
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
+
+ if (node->src[1]->type != vec_dot_type) {
+ cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
+ }
+ } break;
+ case GGML_OP_MUL_MAT_ID:
+ {
+ cur = 0;
+ const struct ggml_tensor * src0 = node->src[0];
+ const struct ggml_tensor * src1 = node->src[1];
+ const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
+ if (src1->type != vec_dot_type) {
+ cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+ }
+ const int n_as = src0->ne[2];
+ cur += GGML_PAD(cur, sizeof(int64_t)); // align
+ cur += n_as * sizeof(int64_t); // matrix_row_counts
+ cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+ } break;
+ case GGML_OP_OUT_PROD:
+ {
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+ }
+ } break;
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
+ {
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+ const int64_t ne00 = node->src[0]->ne[0]; // K
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
+
+ const int64_t ne10 = node->src[1]->ne[0]; // L
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
+
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
+ node->src[0]->type == GGML_TYPE_BF16) &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur += sizeof(float)*ne00*ne01*ne02;
+ cur += sizeof(float)*ne10*ne11;
+ } else {
+ GGML_ASSERT(false);
+ }
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // W
+ const int64_t ne01 = node->src[0]->ne[1]; // H
+ const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
+ const int64_t ne03 = node->src[0]->ne[3]; // Channels In
+
+ const int64_t ne10 = node->src[1]->ne[0]; // W
+ const int64_t ne11 = node->src[1]->ne[1]; // H
+ const int64_t ne12 = node->src[1]->ne[2]; // Channels In
+
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
+ cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+ } break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ const int64_t ne00 = node->src[0]->ne[0]; // D
+
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ const int64_t D = node->src[0]->ne[0];
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+ if (node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ } else if (node->src[1]->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }
+ } break;
+
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ default:
+ break;
+ }
+
+ work_size = MAX(work_size, cur);
+ }
+
+ if (work_size > 0) {
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
+ }
+
+ cplan.n_threads = MIN(max_tasks, n_threads);
+ cplan.work_size = work_size;
+ cplan.work_data = NULL;
+
+ return cplan;
+}
+
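+// per-worker loop: all threads walk the node list in lockstep, each computing its share
+// of the current node (selected via params.ith/params.nth) and synchronizing on a
+// barrier after every node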
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+ struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
+ const struct ggml_cplan * cplan = state->shared->cplan;
+
+ set_numa_thread_affinity(state->ith);
+
+ struct ggml_compute_params params = {
+ /*.ith =*/ state->ith,
+ /*.nth =*/ state->shared->n_threads,
+ /*.wsize =*/ cplan->work_size,
+ /*.wdata =*/ cplan->work_data,
+ /*.shared=*/ state->shared,
+ };
+
+ for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+ struct ggml_tensor * node = cgraph->nodes[node_n];
+
+ ggml_compute_forward(&params, node);
+
+ if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+ state->shared->ec = GGML_STATUS_ABORTED;
+ }
+
+ ggml_barrier(state->shared);
+
+ if (state->shared->ec != GGML_STATUS_SUCCESS) {
+ break;
+ }
+ }
+
+ return 0;
+}
+
+enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+ GGML_ASSERT(cplan);
+ GGML_ASSERT(cplan->n_threads > 0);
+ GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
+
+ int n_threads = cplan->n_threads;
+
+ struct ggml_compute_state_shared state_shared = {
+ /*.cgraph =*/ cgraph,
+ /*.cgraph_plan =*/ cplan,
+ /*.n_threads =*/ n_threads,
+ /*.n_barrier =*/ 0,
+ /*.n_barrier_passed =*/ 0,
+ /*.abort_callback =*/ NULL,
+ /*.abort_callback_data =*/ NULL,
+ /*.current_chunk =*/ 0,
+ /*.ec =*/ GGML_STATUS_SUCCESS,
+ };
+
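+    // when built with OpenMP the workers are the threads of an OpenMP parallel region;
+    // otherwise a pthread pool is created for this call and joined before returning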
+#ifdef GGML_USE_OPENMP
+ if (n_threads > 1) {
+ #pragma omp parallel num_threads(n_threads)
+ {
+ #pragma omp single
+ {
+ // update the number of threads from the actual number of threads that we got from OpenMP
+ n_threads = omp_get_num_threads();
+ state_shared.n_threads = n_threads;
+ }
+
+ struct ggml_compute_state worker = {
+ .thrd = 0,
+ .ith = omp_get_thread_num(),
+ .shared = &state_shared,
+ };
+ ggml_graph_compute_thread(&worker);
+ }
+ } else {
+ struct ggml_compute_state worker = {
+ .thrd = 0,
+ .ith = 0,
+ .shared = &state_shared,
+ };
+ ggml_graph_compute_thread(&worker);
+ }
+#else
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
+ for (int j = 0; j < n_threads; ++j) {
+ workers[j] = (struct ggml_compute_state) {
+ .thrd = 0,
+ .ith = j,
+ .shared = &state_shared,
+ };
+ }
+
+ // create thread pool
+ for (int j = 1; j < n_threads; ++j) {
+ const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+ GGML_ASSERT(rc == 0);
+ UNUSED(rc);
+ }
+
+ // this is a work thread too
+ ggml_graph_compute_thread(&workers[0]);
+
+ // join or kill thread pool
+ if (n_threads > 1) {
+ for (int j = 1; j < n_threads; j++) {
+ const int rc = ggml_thread_join(workers[j].thrd, NULL);
+ GGML_ASSERT(rc == 0);
+ UNUSED(rc);
+ }
+ }
+#endif
+
+ // don't leave affinity set on the main thread
+ clear_numa_thread_affinity();
+
+ return state_shared.ec;
+}
+
+enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
+
+ cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+
+ return ggml_graph_compute(cgraph, &cplan);
+}
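+
+// example: a minimal sketch of driving ggml_graph_plan()/ggml_graph_compute() with a
+// caller-owned work buffer instead of a ggml context (error handling omitted):
+//
+//   struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+//   uint8_t * work = cplan.work_size > 0 ? malloc(cplan.work_size) : NULL;
+//   cplan.work_data = work;
+//   enum ggml_status status = ggml_graph_compute(cgraph, &cplan);
+//   free(work);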
+
+struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ struct ggml_tensor * leaf = cgraph->leafs[i];
+
+ if (strcmp(leaf->name, name) == 0) {
+ return leaf;
+ }
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ if (strcmp(node->name, name) == 0) {
+ return node;
+ }
+ }
+
+ return NULL;
+}
+
+static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) {
+ const int64_t * ne = tensor->ne;
+ const size_t * nb = tensor->nb;
+
+ fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
+ ggml_type_name(tensor->type),
+ ggml_op_name (tensor->op),
+ ggml_n_dims(tensor),
+ ne[0], ne[1], ne[2], ne[3],
+ nb[0], nb[1], nb[2], nb[3],
+ tensor->data,
+ tensor->name);
+}
+
+static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) {
+ const int64_t * ne = tensor->ne;
+ const size_t * nb = tensor->nb;
+
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
+ arg,
+ ggml_type_name(tensor->type),
+ ggml_op_name (tensor->op),
+ ggml_n_dims(tensor),
+ ne[0], ne[1], ne[2], ne[3],
+ nb[0], nb[1], nb[2], nb[3],
+ tensor->data,
+ tensor->name);
+}
+
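+// format written below: a plain-text summary on stdout, followed by a binary file with a
+// header (magic, version, n_leafs, n_nodes, size_eval), one record per leaf (type, op,
+// ne/nb per dimension, name, op_params, raw tensor data) and one record per node (the
+// same metadata plus GGML_MAX_SRC int32 argument indices, where -1 means "unused",
+// values < n_leafs index into the leafs and larger values index into the nodes,
+// offset by n_leafs)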
+void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
+ uint64_t size_eval = 0;
+
+ // compute size of intermediate results
+ // TODO: does not take into account scratch buffers !!!!
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ size_eval += ggml_nbytes_pad(cgraph->nodes[i]);
+ }
+
+ // print
+ {
+ FILE * fout = stdout;
+
+ fprintf(fout, "\n");
+ fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC);
+ fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+ fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs);
+ fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes);
+ fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
+
+ // header
+ fprintf(fout, "\n");
+ fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n",
+ "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME");
+
+ for (int i = 0; i < cgraph->n_leafs; ++i) {
+ ggml_graph_export_leaf(cgraph->leafs[i], fout);
+
+ GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
+ }
+
+ // header
+ fprintf(fout, "\n");
+        fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n",
+                "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME");
+
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ ggml_graph_export_node(cgraph->nodes[i], "DST", fout);
+
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ if (cgraph->nodes[i]->src[j]) {
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
+ }
+ }
+
+ fprintf(fout, "\n");
+ }
+
+ fprintf(fout, "\n");
+ }
+
+ // write binary data
+ {
+ FILE * fout = ggml_fopen(fname, "wb");
+
+ if (!fout) {
+ fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno));
+ return;
+ }
+
+ // header
+ {
+ const uint32_t magic = GGML_FILE_MAGIC;
+ const uint32_t version = GGML_FILE_VERSION;
+ const uint32_t n_leafs = cgraph->n_leafs;
+ const uint32_t n_nodes = cgraph->n_nodes;
+
+ fwrite(&magic, sizeof(uint32_t), 1, fout);
+ fwrite(&version, sizeof(uint32_t), 1, fout);
+ fwrite(&n_leafs, sizeof(uint32_t), 1, fout);
+ fwrite(&n_nodes, sizeof(uint32_t), 1, fout);
+ fwrite(&size_eval, sizeof(uint64_t), 1, fout);
+ }
+
+ // leafs
+ {
+ for (int i = 0; i < cgraph->n_leafs; ++i) {
+ const struct ggml_tensor * tensor = cgraph->leafs[i];
+
+ const uint32_t type = tensor->type;
+ const uint32_t op = tensor->op;
+
+ fwrite(&type, sizeof(uint32_t), 1, fout);
+ fwrite(&op, sizeof(uint32_t), 1, fout);
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ const uint64_t ne = tensor->ne[j];
+ const uint64_t nb = tensor->nb[j];
+
+ fwrite(&ne, sizeof(uint64_t), 1, fout);
+ fwrite(&nb, sizeof(uint64_t), 1, fout);
+ }
+
+ fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+ fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
+
+ // dump the data
+ // TODO: pad this to 32 byte boundary
+ {
+ const size_t size = ggml_nbytes(tensor);
+
+ fwrite(tensor->data, sizeof(char), size, fout);
+ }
+ }
+ }
+
+ // nodes
+ {
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ const struct ggml_tensor * tensor = cgraph->nodes[i];
+
+ const uint32_t type = tensor->type;
+ const uint32_t op = tensor->op;
+
+ fwrite(&type, sizeof(uint32_t), 1, fout);
+ fwrite(&op, sizeof(uint32_t), 1, fout);
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ const uint64_t ne = tensor->ne[j];
+ const uint64_t nb = tensor->nb[j];
+
+ fwrite(&ne, sizeof(uint64_t), 1, fout);
+ fwrite(&nb, sizeof(uint64_t), 1, fout);
+ }
+
+ fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+ fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
+
+ // output the op arguments
+ {
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
+
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ args[j] = tensor->src[j];
+ }
+
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ if (args[j]) {
+ int32_t idx = -1;
+
+ // check if leaf
+ {
+ for (int k = 0; k < cgraph->n_leafs; ++k) {
+ if (args[j] == cgraph->leafs[k]) {
+ idx = k;
+ break;
+ }
+ }
+ }
+
+ // check if node
+ if (idx == -1) {
+ for (int k = 0; k < cgraph->n_nodes; ++k) {
+ if (args[j] == cgraph->nodes[k]) {
+ idx = cgraph->n_leafs + k;
+ break;
+ }
+ }
+ }
+
+ if (idx == -1) {
+ fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
+ fclose(fout);
+ return;
+ }
+
+ fwrite(&idx, sizeof(int32_t), 1, fout);
+ } else {
+ const int32_t nul = -1;
+
+ fwrite(&nul, sizeof(int32_t), 1, fout);
+ }
+ }
+ }
+ }
+ }
+
+ fclose(fout);
+ }
+}
+
+struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
+ assert(*ctx_data == NULL);
+ assert(*ctx_eval == NULL);
+
+ struct ggml_cgraph * result = NULL;
+
+ struct ggml_tensor * data = NULL;
+
+ // read file into data
+ {
+ FILE * fin = ggml_fopen(fname, "rb");
+ if (!fin) {
+ fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno));
+ return result;
+ }
+
+ size_t fsize = 0;
+
+ fseek(fin, 0, SEEK_END);
+ fsize = ftell(fin);
+ fseek(fin, 0, SEEK_SET);
+
+ // create the data context
+ {
+ const size_t overhead = 1*ggml_tensor_overhead();
+
+ struct ggml_init_params params = {
+ .mem_size = fsize + overhead,
+ .mem_buffer = NULL,
+ .no_alloc = false,
+ };
+
+ *ctx_data = ggml_init(params);
+
+ if (!*ctx_data) {
+ fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+ fclose(fin);
+ return result;
+ }
+ }
+
+ data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize);
+
+ {
+ const size_t ret = fread(data->data, sizeof(char), fsize, fin);
+ if (ret != fsize) {
+ fprintf(stderr, "%s: failed to read %s\n", __func__, fname);
+ fclose(fin);
+ return result;
+ }
+ }
+
+ fclose(fin);
+ }
+
+ // populate result
+ {
+ char * ptr = (char *) data->data;
+
+ const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic);
+
+ if (magic != GGML_FILE_MAGIC) {
+ fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic);
+ return result;
+ }
+
+ const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version);
+
+ if (version != GGML_FILE_VERSION) {
+ fprintf(stderr, "%s: invalid version number\n", __func__);
+ return result;
+ }
+
+ const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
+ const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
+ const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
+ const int graph_size = MAX(n_leafs, n_nodes);
+
+ // create the data context
+ {
+ const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false);
+
+ struct ggml_init_params params = {
+ .mem_size = size_eval + overhead,
+ .mem_buffer = NULL,
+ .no_alloc = true,
+ };
+
+ *ctx_eval = ggml_init(params);
+
+ if (!*ctx_eval) {
+ fprintf(stderr, "%s: failed to create ggml context\n", __func__);
+ return result;
+ }
+ }
+
+ result = ggml_new_graph_custom(*ctx_eval, graph_size, false);
+
+ result->n_leafs = n_leafs;
+ result->n_nodes = n_nodes;
+
+ // leafs
+ {
+ uint32_t type;
+ uint32_t op;
+
+ for (uint32_t i = 0; i < n_leafs; ++i) {
+ type = *(const uint32_t *) ptr; ptr += sizeof(type);
+ op = *(const uint32_t *) ptr; ptr += sizeof(op);
+
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ uint64_t ne_cur;
+ uint64_t nb_cur;
+
+ ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
+ nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
+
+ ne[j] = ne_cur;
+ nb[j] = nb_cur;
+ }
+
+ struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne);
+
+ tensor->op = (enum ggml_op) op;
+
+ memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+ memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS;
+
+ tensor->data = (void *) ptr;
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ tensor->nb[j] = nb[j];
+ }
+
+ result->leafs[i] = tensor;
+
+ ptr += ggml_nbytes(tensor);
+
+ fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor));
+ }
+ }
+
+ ggml_set_no_alloc(*ctx_eval, false);
+
+ // nodes
+ {
+ uint32_t type;
+ uint32_t op;
+
+ for (uint32_t i = 0; i < n_nodes; ++i) {
+ type = *(const uint32_t *) ptr; ptr += sizeof(type);
+ op = *(const uint32_t *) ptr; ptr += sizeof(op);
+
+ enum ggml_op eop = (enum ggml_op) op;
+
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ uint64_t ne_cur;
+ uint64_t nb_cur;
+
+ ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur);
+ nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur);
+
+ ne[j] = ne_cur;
+ nb[j] = nb_cur;
+ }
+
+ const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
+ const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS;
+
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
+
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
+
+ // parse args
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ const int32_t arg_idx = ptr_arg_idx[j];
+
+ if (arg_idx == -1) {
+ continue;
+ }
+
+ if (arg_idx < result->n_leafs) {
+ args[j] = result->leafs[arg_idx];
+ } else {
+ args[j] = result->nodes[arg_idx - result->n_leafs];
+ }
+ }
+
+ // create the tensor
+ // "view" operations are handled differently
+ // TODO: handle inplace ops - currently a copy is always made
+
+ struct ggml_tensor * tensor = NULL;
+
+ switch (eop) {
+ // TODO: implement other view ops
+ case GGML_OP_RESHAPE:
+ {
+ tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
+ } break;
+ case GGML_OP_VIEW:
+ {
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+
+ size_t offs;
+ memcpy(&offs, ptr_op_params, sizeof(offs));
+
+ tensor->data = ((char *) tensor->data) + offs;
+ } break;
+ case GGML_OP_TRANSPOSE:
+ {
+ tensor = ggml_transpose(*ctx_eval, args[0]);
+ } break;
+ case GGML_OP_PERMUTE:
+ {
+ tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
+ } break;
+ default:
+ {
+ tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne);
+
+ tensor->op = eop;
+ } break;
+ }
+
+ memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+ memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS);
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ tensor->nb[j] = nb[j];
+ }
+
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ tensor->src[j] = args[j];
+ }
+
+ result->nodes[i] = tensor;
+
+ fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor));
+ }
+ }
+ }
+
+ return result;
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+ GGML_PRINT("=== GRAPH ===\n");
+
+ GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
+ i,
+ node->ne[0], node->ne[1], node->ne[2],
+ ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
+ }
+
+ GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
+ for (int i = 0; i < cgraph->n_leafs; i++) {
+ struct ggml_tensor * node = cgraph->leafs[i];
+
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
+ i,
+ node->ne[0], node->ne[1],
+ ggml_op_name(node->op),
+ ggml_get_name(node));
+ }
+
+ GGML_PRINT("========================================\n");
+}
+
+// check if node is part of the graph
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ if (cgraph == NULL) {
+ return true;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (cgraph->nodes[i] == node) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * parent = cgraph->nodes[i];
+
+ if (parent->grad == node) {
+ return parent;
+ }
+ }
+
+ return NULL;
+}
+
+static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
+ struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
+ struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
+ gparent0 ? (void *) gparent0 : (void *) parent,
+ gparent0 ? "g" : "x",
+ gparent ? (void *) gparent : (void *) node,
+ gparent ? "g" : "x",
+ gparent ? "empty" : "vee",
+ gparent ? "dashed" : "solid",
+ label);
+}
+
+static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
+ fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
+ (void *) parent, "x",
+ (void *) node, "x",
+ label);
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+ char color[16];
+
+ FILE * fp = ggml_fopen(filename, "w");
+ GGML_ASSERT(fp);
+
+ fprintf(fp, "digraph G {\n");
+ fprintf(fp, " newrank = true;\n");
+ fprintf(fp, " rankdir = TB;\n");
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ if (ggml_graph_get_parent(gb, node) != NULL) {
+ continue;
+ }
+
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+ snprintf(color, sizeof(color), "yellow");
+ } else if (node->grad) {
+ if (ggml_graph_find(gf, node)) {
+ snprintf(color, sizeof(color), "green");
+ } else {
+ snprintf(color, sizeof(color), "lightblue");
+ }
+ } else {
+ snprintf(color, sizeof(color), "white");
+ }
+
+ fprintf(fp, " \"%p\" [ "
+ "style = filled; fillcolor = %s; shape = record; "
+ "label=\"",
+ (void *) node, color);
+
+ if (strlen(node->name) > 0) {
+ fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+ } else {
+ fprintf(fp, "(%s)|", ggml_type_name(node->type));
+ }
+
+ if (ggml_is_matrix(node)) {
+ fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
+ } else {
+ fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
+ }
+
+ if (node->grad) {
+ fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
+ } else {
+ fprintf(fp, "\"; ]\n");
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ snprintf(color, sizeof(color), "pink");
+
+ fprintf(fp, " \"%p\" [ "
+ "style = filled; fillcolor = %s; shape = record; "
+ "label=\"<x>",
+ (void *) node, color);
+
+ if (strlen(node->name) > 0) {
+ fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+ } else {
+ fprintf(fp, "(%s)|", ggml_type_name(node->type));
+ }
+
+ fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
+ if (ggml_nelements(node) < 5 && node->data != NULL) {
+ fprintf(fp, " | (");
+ for (int j = 0; j < ggml_nelements(node); j++) {
+ if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
+ fprintf(fp, "%d", ggml_get_i32_1d(node, j));
+ }
+ else if (node->type == GGML_TYPE_F32 ||
+ node->type == GGML_TYPE_F16 ||
+ node->type == GGML_TYPE_BF16) {
+ fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
+ }
+ else {
+ fprintf(fp, "#");
+ }
+ if (j < ggml_nelements(node) - 1) {
+ fprintf(fp, ", ");
+ }
+ }
+ fprintf(fp, ")");
+ }
+ fprintf(fp, "\"; ]\n");
+ }
+
+ for (int i = 0; i < gb->n_nodes; i++) {
+ struct ggml_tensor * node = gb->nodes[i];
+
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
+ char label[16];
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
+ }
+ }
+ }
+
+ for (int i = 0; i < gb->n_leafs; i++) {
+ struct ggml_tensor * node = gb->leafs[i];
+
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
+ char label[16];
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
+ }
+ }
+ }
+
+ fprintf(fp, "}\n");
+
+ fclose(fp);
+
+ GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]);
+ // TODO: add function to set tensor from array
+ for (int64_t j = 0; j < ne; ++j) {
+ ggml_set_f32_1d(ps[p], j, x[i++]);
+ }
+ }
+}
+
+static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
+ int i = 0;
+ for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]);
+ // TODO: add function to get all elements at once
+ for (int64_t j = 0; j < ne; ++j) {
+ x[i++] = ggml_get_f32_1d(ps[p], j);
+ }
+ }
+}
+
+static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
+ int64_t i = 0;
+ for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]);
+ // TODO: add function to get all elements at once
+ for (int64_t j = 0; j < ne; ++j) {
+ g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
+ }
+ }
+}
+
+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+ int64_t i = 0;
+ for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]);
+ // TODO: add function to get all elements at once
+ for (int64_t j = 0; j < ne; ++j) {
+ g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+ }
+ }
+}
+
+//
+// Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf
+//
+// (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf)
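+//
+// update applied below to each parameter x (element-wise), with g the accumulated and
+// optionally clipped gradient, t = opt->iter and lr = sched*alpha:
+//
+//   m    = beta1*m + (1 - beta1)*g
+//   v    = beta2*v + (1 - beta2)*g*g
+//   mhat = m/(1 - beta1^t),  vhat = v/(1 - beta2^t)
+//   x    = x*(1 - p_decay) - lr*mhat/(sqrt(vhat) + eps)
+//
+// where p_decay is the AdamW weight decay computed below (0 for tensors with fewer than
+// decay_min_ndim dimensions)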
+//
+
+static enum ggml_opt_result ggml_opt_adam(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ ggml_opt_callback callback,
+ void * callback_data) {
+ GGML_ASSERT(ggml_is_scalar(f));
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int64_t nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ GGML_ASSERT(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) {
+ int iter = opt->iter;
+ ggml_opt_init(opt->ctx, opt, params, nx);
+ opt->iter = iter;
+ }
+
+ // constants
+ float sched = params.adam.sched;
+ const float alpha = params.adam.alpha;
+ const float decay = params.adam.decay * alpha;
+ const float beta1 = params.adam.beta1;
+ const float beta2 = params.adam.beta2;
+ const float eps = params.adam.eps;
+ const float gclip = params.adam.gclip;
+ const int decay_min_ndim = params.adam.decay_min_ndim;
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
+ float * g = opt->adam.g->data; // gradients
+ float * m = opt->adam.m->data; // first moment
+ float * v = opt->adam.v->data; // second moment
+
+ float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
+
+ struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
+ cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+
+ bool cancel = false;
+
+ // compute the function value
+ float fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ return GGML_OPT_RESULT_CANCEL;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ fx *= accum_norm;
+
+ opt->adam.fx_prev = fx;
+ opt->adam.fx_best = opt->adam.fx_prev;
+ if (pf) {
+ pf[opt->iter % params.past] = opt->adam.fx_prev;
+ }
+
+ opt->loss_before = opt->adam.fx_prev;
+ opt->loss_after = opt->adam.fx_prev;
+
+ // initialize
+ if (opt->just_initialized) {
+ opt->adam.n_no_improvement = 0;
+ opt->just_initialized = false;
+ }
+
+ float * fx_best = &opt->adam.fx_best;
+ float * fx_prev = &opt->adam.fx_prev;
+ int * n_no_improvement = &opt->adam.n_no_improvement;
+
+ int iter0 = opt->iter;
+
+ // run the optimizer
+ for (int t = 0; t < params.adam.n_iter; ++t) {
+ opt->iter = iter0 + t + 1;
+ GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
+
+ GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+ GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0));
+ GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0));
+
+ for (int i = 0; i < np; ++i) {
+ GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i,
+ ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0));
+ }
+
+ const int64_t t_start_wall = ggml_time_us();
+ const int64_t t_start_cpu = ggml_cycles();
+ UNUSED(t_start_wall);
+ UNUSED(t_start_cpu);
+
+ {
+ float gnorm = 1.0f;
+ if (gclip > 0.0f) {
+ // gradient clipping
+ ggml_float sum = 0.0;
+ for (int64_t i = 0; i < nx; ++i) {
+ sum += (ggml_float)(g[i]*g[i]);
+ }
+ ggml_float norm = sqrt(sum);
+ if (norm > (ggml_float) gclip) {
+ gnorm = (float) ((ggml_float) gclip / norm);
+ }
+ }
+ const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
+ const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
+ int64_t i = 0;
+ for (int p = 0; p < np; ++p) {
+ const int64_t ne = ggml_nelements(ps[p]);
+ const float p_decay = ((ggml_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched;
+ for (int64_t j = 0; j < ne; ++j) {
+ float x = ggml_get_f32_1d(ps[p], j);
+ float g_ = g[i]*gnorm;
+ m[i] = m[i]*beta1 + g_*(1.0f - beta1);
+ v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
+ float mh = m[i]*beta1h;
+ float vh = v[i]*beta2h;
+ vh = sqrtf(vh) + eps;
+ x = x*(1.0f - p_decay) - mh/vh;
+ ggml_set_f32_1d(ps[p], j, x);
+ ++i;
+ }
+ }
+ }
+
+ fx = 0;
+ ggml_set_zero(opt->adam.g);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+                    return GGML_OPT_RESULT_CANCEL;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ fx *= accum_norm;
+
+ opt->loss_after = fx;
+
+ // check convergence
+ if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
+ GGML_PRINT_DEBUG("converged\n");
+
+ return GGML_OPT_RESULT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= iter0 + t) {
+ const float rate = (pf[(iter0 + t)%params.past] - fx)/fx;
+
+ if (fabsf(rate) < params.delta) {
+ return GGML_OPT_RESULT_OK;
+ }
+ }
+
+ pf[(iter0 + t)%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx_best[0] > fx) {
+ fx_best[0] = fx;
+ n_no_improvement[0] = 0;
+ } else {
+ ++n_no_improvement[0];
+
+ if (n_no_improvement[0] >= params.max_no_improvement) {
+ return GGML_OPT_RESULT_OK;
+ }
+ }
+ }
+
+ fx_prev[0] = fx;
+
+ {
+ const int64_t t_end_cpu = ggml_cycles();
+ GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC);
+ UNUSED(t_end_cpu);
+
+ const int64_t t_end_wall = ggml_time_us();
+ GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6);
+ UNUSED(t_end_wall);
+ }
+ }
+
+ return GGML_OPT_RESULT_DID_NOT_CONVERGE;
+}
+
+//
+// L-BFGS
+//
+// the L-BFGS implementation below is based on:
+//
+// https://github.com/chokkan/liblbfgs
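+//
+// each iteration runs a backtracking line search along the current direction d, stores
+// the resulting pair (s = x_new - x_prev, y = g_new - g_prev) in a circular buffer of
+// size m, and rebuilds d with the standard two-loop recursion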
+//
+
+struct ggml_lbfgs_iteration_data {
+ float alpha;
+ float ys;
+ float * s;
+ float * y;
+};
+
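+// backtracking line search: starting from xp, repeatedly evaluates f and g at
+// x = xp + step*d, shrinking (or growing) the step until the Armijo sufficient-decrease
+// condition and, depending on params->lbfgs.linesearch, the (strong) Wolfe curvature
+// condition hold; returns the number of evaluations on success or a negative
+// GGML_LINESEARCH_* code on failure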
+static enum ggml_opt_result linesearch_backtracking(
+ const struct ggml_opt_params * params,
+ int nx,
+ float * x,
+ float * fx,
+ float * g,
+ float * d,
+ float * step,
+ const float * xp,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gb,
+ struct ggml_cplan * cplan,
+ const int np,
+ struct ggml_tensor * ps[],
+ bool * cancel,
+ ggml_opt_callback callback,
+ void * callback_data) {
+ int count = 0;
+
+ float width = 0.0f;
+ float dg = 0.0f;
+ float finit = 0.0f;
+ float dginit = 0.0f;
+ float dgtest = 0.0f;
+
+ const float dec = 0.5f;
+ const float inc = 2.1f;
+
+ const int n_accum = MAX(1, params->n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
+ if (*step <= 0.f) {
+ return GGML_LINESEARCH_INVALID_PARAMETERS;
+ }
+
+ // compute the initial gradient in the search direction
+ ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
+
+ // make sure that d points to a descent direction
+ if (0 < dginit) {
+ return GGML_LINESEARCH_FAIL;
+ }
+
+ // initialize local variables
+ finit = *fx;
+ dgtest = params->lbfgs.ftol*dginit;
+
+ while (true) {
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_mad_f32(nx, x, d, *step);
+
+ // evaluate the function and gradient values
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ *fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+                    // L-BFGS does not support a learning rate -> ignore the learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, cancel);
+ if (*cancel) {
+ return GGML_OPT_RESULT_CANCEL;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ *fx += ggml_get_f32_1d(f, 0);
+ }
+ *fx *= accum_norm;
+
+ }
+
+ ++count;
+
+ if (*fx > finit + (*step)*dgtest) {
+ width = dec;
+ } else {
+ // Armijo condition is satisfied
+ if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) {
+ return count;
+ }
+
+ ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
+
+ // check the Wolfe condition
+ if (dg < params->lbfgs.wolfe * dginit) {
+ width = inc;
+ } else {
+ if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) {
+ // regular Wolfe conditions
+ return count;
+ }
+
+ if(dg > -params->lbfgs.wolfe*dginit) {
+ width = dec;
+ } else {
+ // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
+ return count;
+ }
+ }
+ }
+
+ if (*step < params->lbfgs.min_step) {
+ return GGML_LINESEARCH_MINIMUM_STEP;
+ }
+ if (*step > params->lbfgs.max_step) {
+ return GGML_LINESEARCH_MAXIMUM_STEP;
+ }
+ if (params->lbfgs.max_linesearch <= count) {
+ return GGML_LINESEARCH_MAXIMUM_ITERATIONS;
+ }
+
+ (*step) *= width;
+ }
+
+ GGML_ASSERT(false && "line search failed");
+
+ return GGML_LINESEARCH_FAIL;
+}
+
+static enum ggml_opt_result ggml_opt_lbfgs(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ ggml_opt_callback callback,
+ void * callback_data) {
+ if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
+ params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
+ if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
+ return GGML_OPT_RESULT_INVALID_WOLFE;
+ }
+ }
+
+ const int m = params.lbfgs.m;
+
+ // these will store the parameters we want to optimize
+ struct ggml_tensor * ps[GGML_MAX_PARAMS];
+
+ int np = 0;
+ int nx = 0;
+ for (int i = 0; i < gf->n_nodes; ++i) {
+ if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
+ GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
+
+ GGML_ASSERT(np < GGML_MAX_PARAMS);
+
+ ps[np++] = gf->nodes[i];
+ nx += ggml_nelements(gf->nodes[i]);
+ }
+ }
+
+ if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) {
+ int iter = opt->iter;
+ ggml_opt_init(ctx, opt, params, nx);
+ opt->iter = iter;
+ }
+
+ struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size);
+ cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+
+ float * x = opt->lbfgs.x->data; // current parameters
+ float * xp = opt->lbfgs.xp->data; // previous parameters
+ float * g = opt->lbfgs.g->data; // current gradient
+ float * gp = opt->lbfgs.gp->data; // previous gradient
+ float * d = opt->lbfgs.d->data; // search direction
+
+ float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values
+
+ const int n_accum = MAX(1, params.n_gradient_accumulation);
+ const float accum_norm = 1.0f / (float) n_accum;
+
+ float fx = 0.0f; // cost function value
+ float xnorm = 0.0f; // ||x||
+ float gnorm = 0.0f; // ||g||
+
+ // initialize x from the graph nodes
+ ggml_opt_get_params(np, ps, x);
+
+ // the L-BFGS memory
+ float * lm_alpha = opt->lbfgs.lmal->data;
+ float * lm_ys = opt->lbfgs.lmys->data;
+ float * lm_s = opt->lbfgs.lms->data;
+ float * lm_y = opt->lbfgs.lmy->data;
+
+ bool cancel = false;
+
+ // evaluate the function value and its gradient
+ {
+ ggml_opt_set_params(np, ps, x);
+
+ fx = 0;
+ memset(g, 0, sizeof(float)*nx);
+ for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+ if (callback) {
+                // L-BFGS does not support a learning rate -> ignore the learning schedule
+ float sched = 0;
+ callback(callback_data, accum_step, &sched, &cancel);
+ if (cancel) {
+ return GGML_OPT_RESULT_CANCEL;
+ }
+ }
+ // ggml_graph_reset (gf);
+ ggml_set_f32 (f->grad, 1.0f);
+ ggml_graph_compute(gb, &cplan);
+ ggml_opt_acc_grad(np, ps, g, accum_norm);
+ fx += ggml_get_f32_1d(f, 0);
+ }
+ fx *= accum_norm;
+
+ opt->loss_before = fx;
+ opt->loss_after = fx;
+ }
+
+ // search direction = -gradient
+ ggml_vec_neg_f32(nx, d, g);
+
+ // ||x||, ||g||
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ if (xnorm < 1.0f) {
+ xnorm = 1.0f;
+ }
+
+ // already optimized
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ return GGML_OPT_RESULT_OK;
+ }
+
+ if (opt->just_initialized) {
+ if (pf) {
+ pf[0] = fx;
+ }
+ opt->lbfgs.fx_best = fx;
+
+ // initial step
+ ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d);
+ opt->lbfgs.j = 0;
+ opt->lbfgs.k = 1;
+ opt->lbfgs.end = 0;
+ opt->lbfgs.n_no_improvement = 0;
+ opt->just_initialized = false;
+ }
+
+ float * fx_best = &opt->lbfgs.fx_best;
+ float * step = &opt->lbfgs.step;
+ int * j = &opt->lbfgs.j;
+ int * k = &opt->lbfgs.k;
+ int * end = &opt->lbfgs.end;
+ int * n_no_improvement = &opt->lbfgs.n_no_improvement;
+
+ int ls = 0;
+ int bound = 0;
+
+ float ys = 0.0f;
+ float yy = 0.0f;
+ float beta = 0.0f;
+
+ int it = 0;
+
+ while (true) {
+ // store the current position and gradient vectors
+ ggml_vec_cpy_f32(nx, xp, x);
+ ggml_vec_cpy_f32(nx, gp, g);
+
+ // TODO: instead of passing &cancel here, use the return code of the linesearch
+ // to determine if the optimization should be cancelled
+        //       this is a simple change, but it is not done yet, since there is no nice
+        //       way to test it and I don't want to break something with so many changes lined up
+ ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+ if (cancel) {
+ return GGML_OPT_RESULT_CANCEL;
+ }
+
+ if (ls < 0) {
+ // linesearch failed - go back to the previous point and return
+ ggml_vec_cpy_f32(nx, x, xp);
+ ggml_vec_cpy_f32(nx, g, gp);
+
+ return ls;
+ }
+
+ opt->loss_after = fx;
+
+ ggml_vec_norm_f32(nx, &xnorm, x);
+ ggml_vec_norm_f32(nx, &gnorm, g);
+
+ GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
+
+ if (xnorm < 1.0f) {
+ xnorm = 1.0f;
+ }
+ if (gnorm/xnorm <= params.lbfgs.eps) {
+ // converged
+ return GGML_OPT_RESULT_OK;
+ }
+
+ // delta-based convergence test
+ if (pf != NULL) {
+ // need at least params.past iterations to start checking for convergence
+ if (params.past <= k[0]) {
+ const float rate = (pf[k[0]%params.past] - fx)/fx;
+
+ if (fabsf(rate) < params.delta) {
+ return GGML_OPT_RESULT_OK;
+ }
+ }
+
+ pf[k[0]%params.past] = fx;
+ }
+
+ // check for improvement
+ if (params.max_no_improvement > 0) {
+ if (fx < fx_best[0]) {
+ fx_best[0] = fx;
+ n_no_improvement[0] = 0;
+ } else {
+ n_no_improvement[0]++;
+
+ if (n_no_improvement[0] >= params.max_no_improvement) {
+ return GGML_OPT_RESULT_OK;
+ }
+ }
+ }
+
+ if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) {
+ // reached the maximum number of iterations
+ return GGML_OPT_RESULT_DID_NOT_CONVERGE;
+ }
+
+ // update vectors s and y:
+ // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}.
+ // y_{k+1} = g_{k+1} - g_{k}.
+ //
+ ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp);
+ ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp);
+
+ // compute scalars ys and yy:
+ // ys = y^t \cdot s -> 1 / \rho.
+ // yy = y^t \cdot y.
+ //
+ ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
+ ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
+
+ lm_ys[end[0]] = ys;
+
+ // find new search direction
+ // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS
+
+ bound = (m <= k[0]) ? m : k[0];
+ k[0]++;
+ it++;
+ end[0] = (end[0] + 1)%m;
+
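+        // two-loop recursion (see the reference above): start from q = -g, apply the
+        // stored (s, y) pairs from newest to oldest while accumulating alpha_j, scale by
+        // ys/yy as an estimate of the initial inverse Hessian, then correct from oldest
+        // to newest using beta_j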
+ // initialize search direction with -g
+ ggml_vec_neg_f32(nx, d, g);
+
+ j[0] = end[0];
+ for (int i = 0; i < bound; ++i) {
+ j[0] = (j[0] + m - 1) % m;
+ // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
+ ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
+ lm_alpha[j[0]] /= lm_ys[j[0]];
+ // q_{i} = q_{i+1} - \alpha_{i} y_{i}
+ ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
+ }
+
+ ggml_vec_scale_f32(nx, d, ys/yy);
+
+ for (int i = 0; i < bound; ++i) {
+ // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
+ ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
+ beta /= lm_ys[j[0]];
+ // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
+ ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
+ j[0] = (j[0] + 1)%m;
+ }
+
+ step[0] = 1.0;
+ }
+
+ GGML_ASSERT(false && "lbfgs failed");
+
+ return GGML_OPT_RESULT_DID_NOT_CONVERGE;
+}
+
+struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
+ struct ggml_opt_params result;
+
+ switch (type) {
+ case GGML_OPT_TYPE_ADAM:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_TYPE_ADAM,
+ .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+ .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ?
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 100,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .n_gradient_accumulation = 1,
+
+ .adam = {
+ .n_iter = 10000,
+ .sched = 1.000f,
+ .decay = 0.0f,
+ .decay_min_ndim = 2,
+ .alpha = 0.001f,
+ .beta1 = 0.9f,
+ .beta2 = 0.999f,
+ .eps = 1e-8f,
+ .eps_f = 1e-5f,
+ .eps_g = 1e-3f,
+ .gclip = 0.0f,
+ },
+ };
+ } break;
+ case GGML_OPT_TYPE_LBFGS:
+ {
+ result = (struct ggml_opt_params) {
+ .type = GGML_OPT_TYPE_LBFGS,
+ .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
+
+ .max_no_improvement = 0,
+
+ .print_forward_graph = true,
+ .print_backward_graph = true,
+
+ .n_gradient_accumulation = 1,
+
+ .lbfgs = {
+ .m = 6,
+ .n_iter = 100,
+ .max_linesearch = 20,
+
+ .eps = 1e-5f,
+ .ftol = 1e-4f,
+ .wolfe = 0.9f,
+ .min_step = 1e-20f,
+ .max_step = 1e+20f,
+
+ .linesearch = GGML_LINESEARCH_DEFAULT,
+ },
+ };
+ } break;
+ }
+
+ return result;
+}
+
+GGML_API void ggml_opt_init(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_opt_params params,
+ int64_t nx) {
+ opt->ctx = ctx;
+ opt->params = params;
+ opt->iter = 0;
+ opt->nx = nx;
+ opt->just_initialized = true;
+ if (opt->ctx == NULL) {
+ struct ggml_init_params ctx_opt_params;
+ if (opt->params.type == GGML_OPT_TYPE_ADAM) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ } else if (opt->params.type == GGML_OPT_TYPE_LBFGS) {
+ ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
+ if (opt->params.past > 0) {
+ ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+ }
+ }
+ ctx_opt_params.mem_buffer = NULL;
+ ctx_opt_params.no_alloc = false;
+
+ opt->ctx = ggml_init(ctx_opt_params);
+ }
+ switch (opt->params.type) {
+ case GGML_OPT_TYPE_ADAM:
+ {
+ opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->adam.pf = params.past > 0
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
+ : NULL;
+ ggml_set_zero(opt->adam.m);
+ ggml_set_zero(opt->adam.v);
+ if (opt->adam.pf) {
+ ggml_set_zero(opt->adam.pf);
+ }
+ } break;
+ case GGML_OPT_TYPE_LBFGS:
+ {
+ opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+ opt->lbfgs.pf = params.past > 0
+ ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
+ : NULL;
+ opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+ opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+ ggml_set_zero(opt->lbfgs.x);
+ ggml_set_zero(opt->lbfgs.xp);
+ ggml_set_zero(opt->lbfgs.g);
+ ggml_set_zero(opt->lbfgs.gp);
+ ggml_set_zero(opt->lbfgs.d);
+ if (opt->lbfgs.pf) {
+ ggml_set_zero(opt->lbfgs.pf);
+ }
+ ggml_set_zero(opt->lbfgs.lmal);
+ ggml_set_zero(opt->lbfgs.lmys);
+ ggml_set_zero(opt->lbfgs.lms);
+ ggml_set_zero(opt->lbfgs.lmy);
+ } break;
+ }
+}
+
+enum ggml_opt_result ggml_opt(
+ struct ggml_context * ctx,
+ struct ggml_opt_params params,
+ struct ggml_tensor * f) {
+ bool free_ctx = false;
+ if (ctx == NULL) {
+ struct ggml_init_params params_ctx = {
+ .mem_size = 16*1024*1024,
+ .mem_buffer = NULL,
+ .no_alloc = false,
+ };
+
+ ctx = ggml_init(params_ctx);
+ if (ctx == NULL) {
+ return GGML_OPT_RESULT_NO_CONTEXT;
+ }
+
+ free_ctx = true;
+ }
+
+ enum ggml_opt_result result = GGML_OPT_RESULT_OK;
+
+ struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
+
+ ggml_opt_init(ctx, opt, params, 0);
+ result = ggml_opt_resume(ctx, opt, f);
+
+ if (free_ctx) {
+ ggml_free(ctx);
+ }
+
+ return result;
+}
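+
+// example: a minimal sketch of minimizing a scalar loss f with the default Adam settings
+// (assumes the tensors to optimize were marked with ggml_set_param() when the graph
+// producing f was built):
+//
+//   struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
+//   enum ggml_opt_result res = ggml_opt(ctx, opt_params, f);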
+
+enum ggml_opt_result ggml_opt_resume(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f) {
+
+ // build forward + backward compute graphs
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true);
+ ggml_build_forward_expand(gf, f);
+
+ struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
+ ggml_build_backward_expand(ctx, gf, gb, true);
+
+ return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
+}
+
+enum ggml_opt_result ggml_opt_resume_g(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb,
+ ggml_opt_callback callback,
+ void * callback_data) {
+
+ // build forward + backward compute graphs
+ enum ggml_opt_result result = GGML_OPT_RESULT_OK;
+
+ switch (opt->params.type) {
+ case GGML_OPT_TYPE_ADAM:
+ {
+ result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
+ } break;
+ case GGML_OPT_TYPE_LBFGS:
+ {
+ result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
+ } break;
+ }
+
+ if (opt->params.print_forward_graph) {
+ ggml_graph_print (gf);
+ ggml_graph_dump_dot(gf, NULL, "opt-forward.dot");
+ }
+
+ if (opt->params.print_backward_graph) {
+ ggml_graph_print (gb);
+ ggml_graph_dump_dot(gb, gf, "opt-backward.dot");
+ }
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_input(struct ggml_tensor * tensor) {
+ tensor->flags |= GGML_TENSOR_FLAG_INPUT;
+}
+
+void ggml_set_output(struct ggml_tensor * tensor) {
+ tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_quantize_init(enum ggml_type type) {
+ ggml_critical_section_start();
+
+ switch (type) {
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break;
+ case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
+ case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break;
+ default: // nothing
+ break;
+ }
+
+ ggml_critical_section_end();
+}
+
+void ggml_quantize_free(void) {
+ ggml_critical_section_start();
+
+ iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
+ iq2xs_free_impl(GGML_TYPE_IQ2_XS);
+ iq2xs_free_impl(GGML_TYPE_IQ1_S);
+ iq3xs_free_impl(256);
+
+ ggml_critical_section_end();
+}
+
+bool ggml_quantize_requires_imatrix(enum ggml_type type) {
+ return
+ type == GGML_TYPE_IQ2_XXS ||
+ type == GGML_TYPE_IQ2_XS ||
+ type == GGML_TYPE_IQ1_S;// ||
+ //type == GGML_TYPE_IQ1_M;
+}
+
+size_t ggml_quantize_chunk(
+ enum ggml_type type,
+ const float * src,
+ void * dst,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
+ const float * imatrix) {
+ const int64_t n = (int64_t) nrows * n_per_row;
+
+ if (ggml_quantize_requires_imatrix(type)) {
+ GGML_ASSERT(imatrix != NULL);
+ }
+
+ GGML_ASSERT(start % type_traits[type].blck_size == 0);
+ GGML_ASSERT(start % n_per_row == 0);
+
+    ggml_quantize_init(type); // this is a no-op if already initialized
+
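+    // `start` is an element offset; the asserts above guarantee that it falls on a row
+    // boundary, so the output position can be computed in whole rows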
+ const size_t start_row = start / n_per_row;
+ const size_t row_size = ggml_row_size(type, n_per_row);
+
+ size_t result = 0;
+
+ switch (type) {
+ case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_F16:
+ {
+ size_t elemsize = sizeof(ggml_fp16_t);
+ ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+ result = n * elemsize;
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ size_t elemsize = sizeof(ggml_bf16_t);
+ ggml_fp32_to_bf16_row(src + start, (ggml_bf16_t *)dst + start, n);
+ result = n * elemsize;
+ } break;
+ case GGML_TYPE_F32:
+ {
+ size_t elemsize = sizeof(float);
+ result = n * elemsize;
+ memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+ } break;
+ default:
+ assert(false);
+ }
+
+ GGML_ASSERT(result == nrows * row_size);
+
+ return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct gguf_str {
+ uint64_t n; // GGUFv2
+ char * data;
+};
+
+static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
+ [GGUF_TYPE_UINT8] = sizeof(uint8_t),
+ [GGUF_TYPE_INT8] = sizeof(int8_t),
+ [GGUF_TYPE_UINT16] = sizeof(uint16_t),
+ [GGUF_TYPE_INT16] = sizeof(int16_t),
+ [GGUF_TYPE_UINT32] = sizeof(uint32_t),
+ [GGUF_TYPE_INT32] = sizeof(int32_t),
+ [GGUF_TYPE_FLOAT32] = sizeof(float),
+ [GGUF_TYPE_BOOL] = sizeof(bool),
+ [GGUF_TYPE_STRING] = sizeof(struct gguf_str),
+ [GGUF_TYPE_UINT64] = sizeof(uint64_t),
+ [GGUF_TYPE_INT64] = sizeof(int64_t),
+ [GGUF_TYPE_FLOAT64] = sizeof(double),
+ [GGUF_TYPE_ARRAY] = 0, // undefined
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
+ [GGUF_TYPE_UINT8] = "u8",
+ [GGUF_TYPE_INT8] = "i8",
+ [GGUF_TYPE_UINT16] = "u16",
+ [GGUF_TYPE_INT16] = "i16",
+ [GGUF_TYPE_UINT32] = "u32",
+ [GGUF_TYPE_INT32] = "i32",
+ [GGUF_TYPE_FLOAT32] = "f32",
+ [GGUF_TYPE_BOOL] = "bool",
+ [GGUF_TYPE_STRING] = "str",
+ [GGUF_TYPE_ARRAY] = "arr",
+ [GGUF_TYPE_UINT64] = "u64",
+ [GGUF_TYPE_INT64] = "i64",
+ [GGUF_TYPE_FLOAT64] = "f64",
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+union gguf_value {
+ uint8_t uint8;
+ int8_t int8;
+ uint16_t uint16;
+ int16_t int16;
+ uint32_t uint32;
+ int32_t int32;
+ float float32;
+ uint64_t uint64;
+ int64_t int64;
+ double float64;
+ bool bool_;
+
+ struct gguf_str str;
+
+ struct {
+ enum gguf_type type;
+
+ uint64_t n; // GGUFv2
+ void * data;
+ } arr;
+};
+
+struct gguf_kv {
+ struct gguf_str key;
+
+ enum gguf_type type;
+ union gguf_value value;
+};
+
+struct gguf_header {
+ char magic[4];
+
+ uint32_t version;
+ uint64_t n_tensors; // GGUFv2
+ uint64_t n_kv; // GGUFv2
+};
+
+struct gguf_tensor_info {
+ struct gguf_str name;
+
+ uint32_t n_dims;
+ uint64_t ne[GGML_MAX_DIMS];
+
+ enum ggml_type type;
+
+ uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
+
+ // for writing API
+ const void * data;
+ size_t size;
+};
+
+struct gguf_context {
+ struct gguf_header header;
+
+ struct gguf_kv * kv;
+ struct gguf_tensor_info * infos;
+
+ size_t alignment;
+ size_t offset; // offset of `data` from beginning of file
+ size_t size; // size of `data` in bytes
+
+ //uint8_t * padding;
+ void * data;
+};
+
+static size_t gguf_type_size(enum gguf_type type) {
+ GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
+ return GGUF_TYPE_SIZE[type];
+}
+
+static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
+ GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS);
+ GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT);
+
+ for (uint32_t i = 0; i < info->n_dims; ++i) {
+ GGML_ASSERT(info->ne[i] > 0);
+ }
+
+ // prevent overflow for total number of elements
+ GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]);
+ GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]);
+ GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]);
+}
+
+static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
+ const size_t n = fread(dst, 1, size, file);
+ *offset += n;
+ return n == size;
+}
+
+static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
+ p->n = 0;
+ p->data = NULL;
+
+ bool ok = true;
+
+ ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
+
+    // early exit if the string length is invalid; prevents an integer overflow
+ if (p->n == SIZE_MAX) {
+ fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
+ return false;
+ }
+
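+    // +1 so the string is always NUL-terminated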
+ p->data = GGML_CALLOC(p->n + 1, 1);
+
+ ok = ok && gguf_fread_el(file, p->data, p->n, offset);
+
+ return ok;
+}
+
+static void gguf_free_kv(struct gguf_kv * kv) {
+ if (kv->key.data) {
+ GGML_FREE(kv->key.data);
+ }
+
+ if (kv->type == GGUF_TYPE_STRING) {
+ if (kv->value.str.data) {
+ GGML_FREE(kv->value.str.data);
+ }
+ }
+
+ if (kv->type == GGUF_TYPE_ARRAY) {
+ if (kv->value.arr.data) {
+ if (kv->value.arr.type == GGUF_TYPE_STRING) {
+ for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
+ struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
+ if (str->data) {
+ GGML_FREE(str->data);
+ }
+ }
+ }
+ GGML_FREE(kv->value.arr.data);
+ }
+ }
+}
+
+struct gguf_context * gguf_init_empty(void) {
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
+
+ memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
+ ctx->header.version = GGUF_VERSION;
+ ctx->header.n_tensors = 0;
+ ctx->header.n_kv = 0;
+
+ ctx->kv = NULL;
+ ctx->infos = NULL;
+
+ ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
+ ctx->offset = 0;
+ ctx->size = 0;
+
+ ctx->data = NULL;
+
+ return ctx;
+}
+
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+ FILE * file = ggml_fopen(fname, "rb");
+ if (!file) {
+ fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
+ return NULL;
+ }
+
+ // offset from start of file
+ size_t offset = 0;
+
+ char magic[4];
+
+ // check the magic before making allocations
+ {
+ gguf_fread_el(file, &magic, sizeof(magic), &offset);
+
+ for (uint32_t i = 0; i < sizeof(magic); i++) {
+ if (magic[i] != GGUF_MAGIC[i]) {
+ fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+ fclose(file);
+ return NULL;
+ }
+ }
+ }
+
+ bool ok = true;
+
+ struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context));
+
+ // read the header
+ {
+ strncpy(ctx->header.magic, magic, 4);
+
+ ctx->kv = NULL;
+ ctx->infos = NULL;
+ ctx->data = NULL;
+
+ ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
+ ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
+ ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
+
+ if (ctx->header.version == 1) {
+ fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+        // sanity checks to prevent integer/buffer overflows
+
+ ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
+ ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
+ ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_kv));
+
+ if (!ok) {
+ fprintf(stderr, "%s: failed to read header\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+ }
+
+ // read the kv pairs
+ {
+ const uint64_t n_kv = ctx->header.n_kv;
+
+        // header.n_kv will hold the actual number of pairs that were successfully read in the loop below
+ ctx->header.n_kv = 0;
+ ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv));
+
+ for (uint64_t i = 0; i < n_kv; ++i) {
+ struct gguf_kv * kv = &ctx->kv[i];
+
+ //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
+
+ ok = ok && gguf_fread_str(file, &kv->key, &offset);
+ ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
+
+ //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
+
+ switch (kv->type) {
+ case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break;
+ case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break;
+ case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break;
+ case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break;
+ case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break;
+ case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break;
+ case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
+ case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break;
+ case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break;
+ case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
+ case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break;
+ case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break;
+ case GGUF_TYPE_ARRAY:
+ {
+ ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
+ ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
+
+ switch (kv->value.arr.type) {
+ case GGUF_TYPE_UINT8:
+ case GGUF_TYPE_INT8:
+ case GGUF_TYPE_UINT16:
+ case GGUF_TYPE_INT16:
+ case GGUF_TYPE_UINT32:
+ case GGUF_TYPE_INT32:
+ case GGUF_TYPE_FLOAT32:
+ case GGUF_TYPE_UINT64:
+ case GGUF_TYPE_INT64:
+ case GGUF_TYPE_FLOAT64:
+ case GGUF_TYPE_BOOL:
+ {
+                            // prevent integer overflow in the allocation below
+ if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
+ fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
+
+ ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
+ } break;
+ case GGUF_TYPE_STRING:
+ {
+                            // prevent integer overflow in the allocation below
+ if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
+ fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str));
+
+ for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
+ ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
+ }
+ } break;
+ case GGUF_TYPE_ARRAY:
+ default: GGML_ASSERT(false && "invalid type"); break;
+ }
+ } break;
+ default: GGML_ASSERT(false && "invalid type");
+ }
+
+ if (!ok) {
+ break;
+ }
+
+ ctx->header.n_kv++;
+ }
+
+ if (!ok) {
+ fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+ }
+
+ // read the tensor infos
+ if (ctx->header.n_tensors > 0) {
+ ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
+
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+ struct gguf_tensor_info * info = &ctx->infos[i];
+
+ for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+ info->ne[j] = 1;
+ }
+
+ ok = ok && gguf_fread_str(file, &info->name, &offset);
+ ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
+
+ ok = ok && (info->n_dims <= GGML_MAX_DIMS);
+
+ for (uint32_t j = 0; j < info->n_dims; ++j) {
+ ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
+ }
+
+ ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
+ ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
+
+ // TODO: return an error instead of crashing with GGML_ASSERT
+ gguf_tensor_info_sanitize(info);
+
+            // make sure there are no duplicated tensor names
+ for (uint64_t j = 0; j < i && ok; ++j) {
+ if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
+ fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
+ ok = false;
+ }
+ }
+
+ if (!ok) {
+ fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+ }
+ }
+
+ ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
+
+ int alignment_idx = gguf_find_key(ctx, "general.alignment");
+ if (alignment_idx != -1) {
+ ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
+ }
+
+ // we require the data section to be aligned, so take into account any padding
+ {
+ const size_t offset_pad = offset % ctx->alignment;
+
+ if (offset_pad != 0) {
+ offset += ctx->alignment - offset_pad;
+ fseek(file, offset, SEEK_SET);
+ }
+ }
+
+ // store the current file offset - this is where the data section starts
+ ctx->offset = offset;
+
+ // compute the total size of the data section, taking into account the alignment
+ {
+ ctx->size = 0;
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+ struct gguf_tensor_info * info = &ctx->infos[i];
+
+ const int64_t ne =
+ (int64_t) info->ne[0] *
+ (int64_t) info->ne[1] *
+ (int64_t) info->ne[2] *
+ (int64_t) info->ne[3];
+
+ if (ne % ggml_blck_size(info->type) != 0) {
+ fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
+ __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ const size_t size_cur = ggml_row_size(info->type, ne);
+
+ ctx->size += GGML_PAD(size_cur, ctx->alignment);
+ }
+ }
+
+ // load the tensor data only if requested
+ if (params.ctx != NULL) {
+        // if params.no_alloc is set, we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        // the ggml_tensor structs to the appropriate locations in the binary blob
+
+ // compute the exact size needed for the new ggml_context
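+        // with no_alloc only the ggml_tensor structs are stored in the new context; otherwise
+        // one extra tensor holds the entire binary blob, hence the "+ 1" and "+ ctx->size"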
+ const size_t mem_size =
+ params.no_alloc ?
+ (ctx->header.n_tensors )*ggml_tensor_overhead() :
+ (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+ struct ggml_init_params pdata = {
+ .mem_size = mem_size,
+ .mem_buffer = NULL,
+ .no_alloc = params.no_alloc,
+ };
+
+ *params.ctx = ggml_init(pdata);
+ if (*params.ctx == NULL) {
+ fprintf(stderr, "%s: failed to initialize context\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ struct ggml_context * ctx_data = *params.ctx;
+
+ struct ggml_tensor * data = NULL;
+
+ if (!params.no_alloc) {
+ data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+ ok = ok && data != NULL;
+
+ // read the binary blob with the tensor data
+ ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
+
+ if (!ok) {
+ fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+ fclose(file);
+ ggml_free(ctx_data);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ ctx->data = data->data;
+ }
+
+ ggml_set_no_alloc(ctx_data, true);
+
+ // create the tensors
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+ const int64_t ne[GGML_MAX_DIMS] = {
+ ctx->infos[i].ne[0],
+ ctx->infos[i].ne[1],
+ ctx->infos[i].ne[2],
+ ctx->infos[i].ne[3],
+ };
+
+ struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
+
+ ok = ok && cur != NULL;
+
+ if (!ok) {
+ break;
+ }
+
+ ggml_set_name(cur, ctx->infos[i].name.data);
+
+ // point the data member to the appropriate location in the binary blob using the tensor infos
+ if (!params.no_alloc) {
+ //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
+ cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
+ }
+ }
+
+ if (!ok) {
+ fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
+ fclose(file);
+ ggml_free(ctx_data);
+ gguf_free(ctx);
+ return NULL;
+ }
+
+ ggml_set_no_alloc(ctx_data, params.no_alloc);
+ }
+
+ fclose(file);
+
+ return ctx;
+}
+
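+// A minimal usage sketch (illustration only, not part of this file): load a GGUF file together
+// with its tensor data and look up one tensor by name. The file name "model.gguf" and the tensor
+// name "tok_embeddings.weight" are placeholders.
+//
+//     struct ggml_context * ctx_data = NULL;
+//
+//     struct gguf_init_params params = {
+//         /*.no_alloc =*/ false,
+//         /*.ctx      =*/ &ctx_data,
+//     };
+//
+//     struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
+//     if (gctx) {
+//         const int idx = gguf_find_tensor(gctx, "tok_embeddings.weight");
+//         if (idx >= 0) {
+//             struct ggml_tensor * t = ggml_get_tensor(ctx_data, gguf_get_tensor_name(gctx, idx));
+//             // t->data points into the binary blob that was loaded into ctx_data
+//         }
+//         gguf_free(gctx);
+//         ggml_free(ctx_data);
+//     }
+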
+void gguf_free(struct gguf_context * ctx) {
+ if (ctx == NULL) {
+ return;
+ }
+
+ if (ctx->kv) {
+ // free string memory - not great..
+ for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
+ gguf_free_kv(&ctx->kv[i]);
+ }
+
+ GGML_FREE(ctx->kv);
+ }
+
+ if (ctx->infos) {
+ for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+ struct gguf_tensor_info * info = &ctx->infos[i];
+
+ if (info->name.data) {
+ GGML_FREE(info->name.data);
+ }
+ }
+
+ GGML_FREE(ctx->infos);
+ }
+
+ GGML_FREE(ctx);
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+ return GGUF_TYPE_NAME[type];
+}
+
+int gguf_get_version(const struct gguf_context * ctx) {
+ return ctx->header.version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+ return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+ return ctx->offset;
+}
+
+void * gguf_get_data(const struct gguf_context * ctx) {
+ return ctx->data;
+}
+
+int gguf_get_n_kv(const struct gguf_context * ctx) {
+ return ctx->header.n_kv;
+}
+
+int gguf_find_key(const struct gguf_context * ctx, const char * key) {
+ // return -1 if key not found
+ int keyfound = -1;
+
+ const int n_kv = gguf_get_n_kv(ctx);
+
+ for (int i = 0; i < n_kv; ++i) {
+ if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+ keyfound = i;
+ break;
+ }
+ }
+
+ return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].key.data;
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].type;
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+ return ctx->kv[key_id].value.arr.type;
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+ return ctx->kv[key_id].value.arr.data;
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+ struct gguf_kv * kv = &ctx->kv[key_id];
+ struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
+ return str->data;
+}
+
+int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+ return ctx->kv[key_id].value.arr.n;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
+ return ctx->kv[key_id].value.uint8;
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
+ return ctx->kv[key_id].value.int8;
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
+ return ctx->kv[key_id].value.uint16;
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
+ return ctx->kv[key_id].value.int16;
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
+ return ctx->kv[key_id].value.uint32;
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
+ return ctx->kv[key_id].value.int32;
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
+ return ctx->kv[key_id].value.float32;
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
+ return ctx->kv[key_id].value.uint64;
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
+ return ctx->kv[key_id].value.int64;
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
+ return ctx->kv[key_id].value.float64;
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
+ return ctx->kv[key_id].value.bool_;
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
+ return ctx->kv[key_id].value.str.data;
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
+ GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
+ return &ctx->kv[key_id].value;
+}
+
+int gguf_get_n_tensors(const struct gguf_context * ctx) {
+ return ctx->header.n_tensors;
+}
+
+int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+ // return -1 if tensor not found
+ int tensorfound = -1;
+
+ const int n_tensors = gguf_get_n_tensors(ctx);
+
+ for (int i = 0; i < n_tensors; ++i) {
+ if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+ tensorfound = i;
+ break;
+ }
+ }
+
+ return tensorfound;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
+ return ctx->infos[i].offset;
+}
+
+char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
+ return ctx->infos[i].name.data;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
+ return ctx->infos[i].type;
+}
+
+// returns the index
+static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
+ const int idx = gguf_find_key(ctx, key);
+ if (idx >= 0) {
+ return idx;
+ }
+
+ const int n_kv = gguf_get_n_kv(ctx);
+
+ ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
+ ctx->kv[n_kv].key.n = strlen(key);
+ ctx->kv[n_kv].key.data = strdup(key);
+ ctx->header.n_kv++;
+
+ return n_kv;
+}
+
+void gguf_remove_key(struct gguf_context * ctx, const char * key) {
+ const int idx = gguf_find_key(ctx, key);
+ if (idx >= 0) {
+ const int n_kv = gguf_get_n_kv(ctx);
+ gguf_free_kv(&ctx->kv[idx]);
+ for (int i = idx; i < n_kv-1; ++i) {
+ ctx->kv[i] = ctx->kv[i+1];
+ }
+ ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
+ ctx->header.n_kv--;
+ }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_UINT8;
+ ctx->kv[idx].value.uint8 = val;
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_INT8;
+ ctx->kv[idx].value.int8 = val;
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_UINT16;
+ ctx->kv[idx].value.uint16 = val;
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_INT16;
+ ctx->kv[idx].value.int16 = val;
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_UINT32;
+ ctx->kv[idx].value.uint32 = val;
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_INT32;
+ ctx->kv[idx].value.int32 = val;
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_FLOAT32;
+ ctx->kv[idx].value.float32 = val;
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_UINT64;
+ ctx->kv[idx].value.uint64 = val;
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_INT64;
+ ctx->kv[idx].value.int64 = val;
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_FLOAT64;
+ ctx->kv[idx].value.float64 = val;
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_BOOL;
+ ctx->kv[idx].value.bool_ = val;
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_STRING;
+ ctx->kv[idx].value.str.n = strlen(val);
+ ctx->kv[idx].value.str.data = strdup(val);
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_ARRAY;
+ ctx->kv[idx].value.arr.type = type;
+ ctx->kv[idx].value.arr.n = n;
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
+ memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
+ const int idx = gguf_get_or_add_key(ctx, key);
+
+ ctx->kv[idx].type = GGUF_TYPE_ARRAY;
+ ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
+ ctx->kv[idx].value.arr.n = n;
+ ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
+ for (int i = 0; i < n; i++) {
+ struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
+ str->n = strlen(data[i]);
+ str->data = strdup(data[i]);
+ }
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
+ for (uint32_t i = 0; i < src->header.n_kv; i++) {
+ switch (src->kv[i].type) {
+ case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break;
+ case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break;
+ case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break;
+ case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break;
+ case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break;
+ case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break;
+ case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break;
+ case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break;
+ case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break;
+ case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break;
+ case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break;
+ case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
+ case GGUF_TYPE_ARRAY:
+ {
+ if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
+ const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
+ for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
+ data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
+ }
+ gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
+ GGML_FREE((void *)data);
+ } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
+ GGML_ASSERT(false && "nested arrays not supported");
+ } else {
+ gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
+ }
+ } break;
+ default: GGML_ASSERT(false && "invalid type"); break;
+ }
+ }
+}
+
+void gguf_add_tensor(
+ struct gguf_context * ctx,
+ const struct ggml_tensor * tensor) {
+ if (gguf_find_tensor(ctx, tensor->name) != -1) {
+ GGML_ASSERT(false && "duplicated tensor name");
+ }
+
+ const int idx = ctx->header.n_tensors;
+ ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
+
+ ctx->infos[idx].name.n = strlen(tensor->name);
+ ctx->infos[idx].name.data = strdup(tensor->name);
+
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ ctx->infos[idx].ne[i] = 1;
+ }
+
+ ctx->infos[idx].n_dims = ggml_n_dims(tensor);
+ for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
+ ctx->infos[idx].ne[i] = tensor->ne[i];
+ }
+
+ ctx->infos[idx].type = tensor->type;
+ ctx->infos[idx].offset = 0;
+ ctx->infos[idx].data = tensor->data;
+ ctx->infos[idx].size = ggml_nbytes(tensor);
+
+ if (ctx->header.n_tensors > 0) {
+ ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
+ }
+
+ ctx->header.n_tensors++;
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+ const int idx = gguf_find_tensor(ctx, name);
+ if (idx < 0) {
+ GGML_ASSERT(false && "tensor not found");
+ }
+
+ ctx->infos[idx].type = type;
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
+ const int idx = gguf_find_tensor(ctx, name);
+ if (idx < 0) {
+ GGML_ASSERT(false && "tensor not found");
+ }
+
+ ctx->infos[idx].data = data;
+ ctx->infos[idx].size = size;
+
+ // update offsets
+ for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
+ ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
+ }
+}
+
+//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
+// fwrite(&val->n, sizeof(val->n), 1, file);
+// fwrite(val->data, sizeof(char), val->n, file);
+//}
+//
+//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
+// fwrite(val, sizeof(char), size, file);
+//}
+
+struct gguf_buf {
+ void * data;
+ size_t size;
+ size_t offset;
+};
+
+static struct gguf_buf gguf_buf_init(size_t size) {
+ struct gguf_buf buf = {
+ /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
+ /*buf.size =*/ size,
+ /*buf.offset =*/ 0,
+ };
+
+ return buf;
+}
+
+static void gguf_buf_free(struct gguf_buf buf) {
+ if (buf.data) {
+ GGML_FREE(buf.data);
+ }
+}
+
+static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
+ if (buf->offset + size > buf->size) {
+ buf->size = 1.5*(buf->offset + size);
+ if (buf->data) {
+ buf->data = realloc(buf->data, buf->size);
+ }
+ }
+}
+
+static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
+ gguf_buf_grow(buf, sizeof(val->n) + val->n);
+
+ if (buf->data) {
+ memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
+ }
+ buf->offset += sizeof(val->n);
+
+ if (buf->data) {
+ memcpy((char *) buf->data + buf->offset, val->data, val->n);
+ }
+ buf->offset += val->n;
+}
+
+static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
+ gguf_buf_grow(buf, el_size);
+
+ if (buf->data) {
+ memcpy((char *) buf->data + buf->offset, val, el_size);
+ }
+ buf->offset += el_size;
+}
+
+static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
+ // write header
+ gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic));
+ gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version));
+ gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
+ gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv));
+
+ // write key-value pairs
+ for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
+ struct gguf_kv * kv = &ctx->kv[i];
+
+ gguf_bwrite_str(buf, &kv->key);
+ gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
+
+ switch (kv->type) {
+            case GGUF_TYPE_UINT8:   gguf_bwrite_el (buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
+ case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break;
+ case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break;
+ case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break;
+ case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break;
+ case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break;
+ case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
+ case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break;
+ case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break;
+ case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
+ case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break;
+ case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break;
+ case GGUF_TYPE_ARRAY:
+ {
+ gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
+ gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) );
+
+ switch (kv->value.arr.type) {
+ case GGUF_TYPE_UINT8:
+ case GGUF_TYPE_INT8:
+ case GGUF_TYPE_UINT16:
+ case GGUF_TYPE_INT16:
+ case GGUF_TYPE_UINT32:
+ case GGUF_TYPE_INT32:
+ case GGUF_TYPE_FLOAT32:
+ case GGUF_TYPE_UINT64:
+ case GGUF_TYPE_INT64:
+ case GGUF_TYPE_FLOAT64:
+ case GGUF_TYPE_BOOL:
+ {
+ gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
+ } break;
+ case GGUF_TYPE_STRING:
+ {
+ for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
+ gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
+ }
+ } break;
+ case GGUF_TYPE_ARRAY:
+ default: GGML_ASSERT(false && "invalid type"); break;
+ }
+ } break;
+ default: GGML_ASSERT(false && "invalid type");
+ }
+ }
+
+ // write tensor infos
+ for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+ struct gguf_tensor_info * info = &ctx->infos[i];
+
+ gguf_bwrite_str(buf, &info->name);
+ gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
+ for (uint32_t j = 0; j < info->n_dims; ++j) {
+ gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
+ }
+ gguf_bwrite_el(buf, &info->type, sizeof(info->type));
+ gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
+ }
+
+ // we require the data section to be aligned, so take into account any padding
+ {
+ const size_t offset = buf->offset;
+ const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
+
+ if (offset_pad != offset) {
+ uint8_t pad = 0;
+ for (size_t i = 0; i < offset_pad - offset; ++i) {
+ gguf_bwrite_el(buf, &pad, sizeof(pad));
+ }
+ }
+ }
+
+ if (only_meta) {
+ return;
+ }
+
+ size_t offset = 0;
+
+ // write tensor data
+ for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+ struct gguf_tensor_info * info = &ctx->infos[i];
+
+ const size_t size = info->size;
+ const size_t size_pad = GGML_PAD(size, ctx->alignment);
+
+ gguf_bwrite_el(buf, info->data, size);
+
+ if (size_pad != size) {
+ uint8_t pad = 0;
+ for (size_t j = 0; j < size_pad - size; ++j) {
+ gguf_bwrite_el(buf, &pad, sizeof(pad));
+ }
+ }
+
+ GGML_ASSERT(offset == info->offset);
+
+ offset += size_pad;
+ }
+}
+
+void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+ FILE * file = ggml_fopen(fname, "wb");
+ if (!file) {
+ GGML_ASSERT(false && "failed to open file for writing");
+ }
+
+ struct gguf_buf buf = gguf_buf_init(16*1024);
+
+ gguf_write_to_buf(ctx, &buf, only_meta);
+
+ fwrite(buf.data, 1, buf.offset, file);
+
+ gguf_buf_free(buf);
+
+ fclose(file);
+}
+
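+// A minimal writer-side sketch (illustration only, not part of this file): create a GGUF file
+// with one key-value pair and one tensor. The key, the file name and the tensor `t` (created in
+// some ggml_context beforehand) are placeholders.
+//
+//     struct gguf_context * gctx = gguf_init_empty();
+//
+//     gguf_set_val_str(gctx, "some.key", "some value");
+//     gguf_add_tensor(gctx, t);                     // records the tensor info and a pointer to t->data
+//
+//     gguf_write_to_file(gctx, "out.gguf", false);  // false: write the tensor data as well
+//     gguf_free(gctx);
+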
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+ // no allocs - only compute size
+ struct gguf_buf buf = gguf_buf_init(0);
+
+ gguf_write_to_buf(ctx, &buf, true);
+
+ return buf.offset;
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+ struct gguf_buf buf = gguf_buf_init(16*1024);
+
+ gguf_write_to_buf(ctx, &buf, true);
+
+ memcpy(data, buf.data, buf.offset);
+
+ gguf_buf_free(buf);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx_vnni(void) {
+#if defined(__AVXVNNI__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_NEON)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_metal(void) {
+#if defined(GGML_USE_METAL)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_blas(void) {
+#if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_cuda(void) {
+#if defined(GGML_USE_CUDA)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_vulkan(void) {
+#if defined(GGML_USE_VULKAN)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_kompute(void) {
+#if defined(GGML_USE_KOMPUTE)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_sycl(void) {
+#if defined(GGML_USE_SYCL)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_rpc(void) {
+#if defined(GGML_USE_RPC)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_cann(void) {
+#if defined(GGML_USE_CANN)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_llamafile(void) {
+#if defined(GGML_USE_LLAMAFILE)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_gpublas(void) {
+ return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl();
+}
+
+int ggml_cpu_has_sse3(void) {
+#if defined(__SSE3__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_ssse3(void) {
+#if defined(__SSSE3__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_vsx(void) {
+#if defined(__POWER9_VECTOR__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
new file mode 100644
index 00000000..bf517504
--- /dev/null
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -0,0 +1,4757 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp fenc=utf-8 :vi
+//
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#if defined IQK_IMPLEMENT
+#undef IQK_IMPLEMENT
+#endif
+
+#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD
+#define IQK_IMPLEMENT
+#endif
+
+#include <cstring>
+#include <type_traits>
+
+#if defined IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "iqk_mul_mat.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+// clang-format off
+
+// This matrix-vector and matrix-matrix multiplication implementation
+// for k-quants, i-quants, and legacy quants makes prompt processing
+// 150-350% faster (depending on quantization type) compared to mainline llama.cpp.
+// It is AVX2 and ARM_NEON only for now.
+// There are also implementations for fp16/32 x fp16/32 matrix multiplications
+// on AVX2 and fp16 x fp16 on ARM_NEON.
+//
+// The main idea is that unpacking the quants and the block scales so that they are
+// ready for dot products with the corresponding Q8_X quants takes time. Hence, if we
+// are performing a QX x Q8_X matrix-matrix multiplication (as needed for prompt
+// processing), we can get a significant speedup by reusing the unpacked QX quants and
+// scales for multiplication with several Q8_X columns.
+//
+// For fp16/fp32 matrix multiplications, tiling is used to improve performance.
+
+#include <utility>
+#include <array>
+
+#ifdef _MSC_VER
+#define IQK_NOINLINE __declspec(noinline)
+#define IQK_ALWAYS_INLINE inline
+#else
+#define IQK_NOINLINE __attribute__((__noinline__))
+#define IQK_ALWAYS_INLINE __attribute__((__always_inline__))
+#endif
+
+namespace {
+
+typedef struct {
+ int32_t i1;
+ int32_t i2;
+} mmid_row_mapping;
+
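+// Bookkeeping passed to the per-type kernels: `s` is the output (row stride `bs`), `cy` the
+// quantized right-hand side (row stride `by`), and `cur_y` counts how many y columns have been
+// processed so far. When `row_mapping` is set (the MoE / mul_mat_id path), each y index is
+// redirected to its (i1, i2) source and destination rows, with `bs2` as the second output stride.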
+struct DataInfo {
+ float * s;
+ const char * cy;
+ size_t bs;
+ size_t by;
+ int cur_y = 0;
+ int ne11;
+ const mmid_row_mapping * row_mapping = nullptr;
+ size_t bs2 = 0;
+
+ inline const char * src1_row(int iy) const {
+ if (!row_mapping) return cy + (cur_y + iy)*by;
+ int i11 = row_mapping[cur_y + iy].i1 % ne11;
+ int i12 = row_mapping[cur_y + iy].i2;
+ return cy + (i11 + i12*ne11)*by;
+ }
+
+ inline void store(int ix, int iy, float result) const {
+ *(dst_row(iy) + ix) = result;
+ }
+ inline float * dst_row(int iy) const {
+ if (!row_mapping) return s + (cur_y + iy)*bs;
+ int i12 = row_mapping[cur_y + iy].i2;
+ int i1 = row_mapping[cur_y + iy].i1;
+ int i2 = i12;
+ return s + i1*bs + i2*bs2;
+ }
+};
+
+typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+
+struct MulMat {
+ std::array<mul_mat_t, 8> funcs = {};
+ inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
+#ifdef __aarch64__
+        constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but the difference due to tiling is small)
+#else
+        constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but the differences to other tile sizes are small)
+#endif
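+        // process the y columns in groups of `ny` (the widest kernel available), tiling the x
+        // direction in chunks of k_x_step rows; leftover y columns are handled by a narrower kernel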
+ int ny = funcs.size();
+        while (ny > 0 && !funcs[ny-1]) --ny; // check ny > 0 first so funcs[-1] is never read
+ int n_step = (nrc_y - info.cur_y)/ny;
+ if (n_step > 0) {
+ for (int ix = 0; ix < nrc_x; ix += k_x_step) {
+ auto this_info = info;
+ this_info.s += ix;
+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
+ for (int iy = 0; iy < n_step; ++iy) {
+ funcs[ny-1](n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);
+ this_info.cur_y += ny;
+ }
+ }
+ info.cur_y += ny * n_step;
+ }
+ int n_left = nrc_y - info.cur_y;
+ if (n_left > 0) {
+ funcs[n_left-1](n, vx, bx, info, nrc_x);
+ }
+ }
+ static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny);
+private:
+ template <typename Dequantizer> static void set_functions(MulMat& m);
+};
+
+}
+
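+// Entry point: computes the Nx x Ny matrix of dot products between the rows of A (type typeA)
+// and the rows of B (type typeB), writing the results to C (row stride stride_C floats).
+// Work is split along the Nx direction: thread ith of nth handles its own contiguous slice of rows.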
+bool iqk_mul_mat(long Nx, long Ny, long ne00,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
+ float * C, long stride_C, int ith, int nth) {
+
+ MulMat mm;
+ if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
+ return false;
+ }
+
+ auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
+ auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
+
+ auto nrc_x = (Nx + nth - 1)/nth;
+ auto first_x = ith*nrc_x;
+ if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+
+ DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+
+ mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+
+ return true;
+}
+
+bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
+ float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
+ const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
+ assert(row_mapping != nullptr);
+
+ MulMat mm;
+ if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
+ return false;
+ }
+ auto row_size_qx = strideA*ggml_type_size(ggml_type(typeA));
+ auto row_size_qy = strideB*ggml_type_size(ggml_type(typeB));
+ int nrc_x = (Nx + nth - 1)/nth;
+ int first_x = ith*nrc_x;
+ if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
+ row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
+ mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
+ return true;
+}
+
+namespace {
+
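+// Unpacks the 12 bytes of packed 6-bit scales/mins used by Q4_K/Q5_K super-blocks:
+// aux32[0..1] receive the 8 scales and aux32[2..3] the 8 mins, one byte per value.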
+inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
+ const uint16_t * scales = (const uint16_t *)scales8;
+ const uint32_t a0 = scales[0] | (scales[1] << 16);
+ const uint32_t a1 = scales[2] | (scales[3] << 16);
+ const uint32_t a2 = scales[4] | (scales[5] << 16);
+ aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);
+ aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);
+ aux32[2] = a1 & 0x3f3f3f3f;
+ aux32[0] = a0 & 0x3f3f3f3f;
+}
+
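+// Sign patterns used by the iq2/iq3 quants: each 64-bit entry packs eight +1/-1 factors as
+// bytes (0x01 / 0xff). The low 7 bits of the index select the first seven signs; the eighth
+// sign is chosen so that the total number of -1 factors is even.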
+const uint64_t keven_signs[128] = {
+ 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
+ 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
+ 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,
+ 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,
+ 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,
+ 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,
+ 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,
+ 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,
+ 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,
+ 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,
+ 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,
+ 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,
+ 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,
+ 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,
+ 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,
+ 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,
+ 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,
+ 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,
+ 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,
+ 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,
+ 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,
+ 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,
+ 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,
+ 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,
+ 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,
+ 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,
+ 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,
+ 0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,
+ 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,
+ 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,
+ 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,
+ 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,
+};
+
+}
+
+#if defined __x86_64__
+
+#if defined HAVE_FANCY_SIMD
+ #undef HAVE_FANCY_SIMD
+#endif
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)
+ #define HAVE_FANCY_SIMD
+#endif
+
+namespace {
+
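+// horizontal-sum helpers: reduce a SIMD register to a single scalar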
+inline float hsum_float_4(__m128 x) {
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+ return _mm_cvtss_f32(x);
+}
+inline float hsum_float_8(__m256 x) {
+ return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));
+}
+inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
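+// Thin wrapper over the nrc_y right-hand side rows (block_q8_K by default): typed loads of
+// the quants, the block sums (bsums) and the scale of block i in row iy.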
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+#ifdef HAVE_FANCY_SIMD
+ inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }
+#endif
+ inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }
+ inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
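+// Helper for the k-quant "mins" term: multiplies the per-block mins with the Q8 block sums
+// (bsums) and folds the result into the per-column float accumulators, scaled by c times the
+// Q8 block scale.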
+struct Scales8KBase {
+ template <typename Q8>
+ inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
+ const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i q8s = q8.load_bsums(iy, i);
+ const __m256i prod = _mm256_madd_epi16(mins, q8s);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+ }
+ inline __m256i shuffle(__m128i mins) const {
+ return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));
+ }
+ const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),
+ _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};
+};
+
+// Handles q4_K and q5_K scales/mins
+struct Scales8K {
+ template <typename Q8>
+ inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+ make_q4_scales(data, utmp);
+ const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+ const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
+ accum_mins(mins128, q8, i, c, accd);
+ const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
+ return MM256_SET_M128I(sc128, sc128);
+ }
+#ifdef HAVE_FANCY_SIMD
+ template <typename Q8>
+ inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
+ auto scales = process_mins_and_scales(data, c, i, q8, accd);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);
+ }
+#endif
+ template <typename Q8>
+ inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
+ base.accum_mins(mins128, q8, i, c, accd);
+ }
+#ifdef HAVE_FANCY_SIMD
+ const __m512i shuffles512[2] = {
+ _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
+ 0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
+ _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
+ 0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
+ };
+#endif
+ Scales8KBase base;
+
+ uint32_t utmp[4];
+};
+
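+// Same min/scale handling for quants with 16 scale groups per super-block (q2_K, q3_K, q6_K).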
+template <typename Q8>
+inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));
+ accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
+ }
+}
+inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
+ const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+ const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+ scales[0] = MM256_SET_M128I(l_scales, l_scales);
+ scales[1] = MM256_SET_M128I(h_scales, h_scales);
+}
+
+struct ScaleQ3 {
+ inline __m128i make_scales(const uint16_t * s8) const {
+ const uint16_t * scales16 = (const uint16_t *)s8;
+ uint32_t aux0 = scales16[0] | (scales16[1] << 16);
+ uint32_t aux1 = scales16[2] | (scales16[3] << 16);
+ uint32_t aux2 = scales16[4] | (scales16[5] << 16);
+ __m128i scales128 = _mm_set_epi32(
+ ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),
+ ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),
+ (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),
+ (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));
+ return _mm_add_epi8(scales128, m32);
+ }
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct ScaleIQ4XS {
+ inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {
+ uint32_t tmp32 = scales_h | (scales_h << 14);
+ const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);
+ const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);
+ return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);
+ }
+ const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);
+ const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);
+ const __m128i hmask = _mm_set1_epi16(0x03);
+ const __m128i lmask = _mm_set1_epi8(0xf);
+ const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);
+ const __m128i m32 = _mm_set1_epi16(-32);
+};
+
+template <typename Block>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}
+ inline void new_row(int ix) {
+ x = (const Block *)((const char *)vx + bx*ix);
+ }
+
+ const void * vx;
+ const size_t bx;
+ const Block * x;
+
+ float d;
+};
+
+inline __m256i get_scale_shuffle_8(int i) {
+ return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
+}
+
+inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
+}
+
+inline __m256i get_scale_shuffle_16(int i) {
+ static const uint8_t k_shuffle[128] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+
+inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
+ scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
+ scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
+ scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
+ scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
+}
+
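+// multiply_add: accumulates one 128-quant half of a super-block into sumi for every q8 row.
+// With AVX512-VNNI the 16-bit products are fused into the 32-bit accumulators via _mm256_dpwssd_epi32;
+// otherwise we fall back to _mm256_madd_epi16 followed by plain 32-bit adds.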
+template <typename Q8, typename Bits>
+inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+ if (j == 0) {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));
+ sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));
+ }
+#endif
+ } else {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ }
+#else
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+ sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+ }
+#endif
+ }
+}
+
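+// SignHelper: applies packed sign bits to the unpacked quant values, using masked subtraction on
+// AVX512 (HAVE_FANCY_SIMD) and _mm256_sign_epi8 with an expanded sign mask on vanilla AVX2.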
+struct SignHelper {
+ inline __m256i make_signs(uint32_t sign_bits) const {
+ auto aux256 = _mm256_set1_epi32(sign_bits);
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256, mask1), mask2);
+ return _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone);
+ }
+// inline __m256i make_signs(const uint16_t * sign_bits) const {
+//#ifdef HAVE_FANCY_SIMD
+//#else
+// return make_signs(sign_bits[0] | (sign_bits[1] << 16));
+//#endif
+// }
+ inline __m256i sign_value(const uint16_t * sign_bits, const __m256i& value) const {
+#ifdef HAVE_FANCY_SIMD
+ const __mmask32 * mask = (const __mmask32 *)sign_bits;
+ return _mm256_mask_sub_epi8(value, mask[0], _mm256_setzero_si256(), value);
+#else
+ return _mm256_sign_epi8(value, make_signs(sign_bits[0] | (sign_bits[1] << 16)));
+#endif
+ }
+ inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
+#ifdef HAVE_FANCY_SIMD
+ const __mmask32 * mask = (const __mmask32 *)sign_bits;
+ values[0] = _mm256_mask_sub_epi8(values[0], mask[0], _mm256_setzero_si256(), values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], mask[1], _mm256_setzero_si256(), values[1]);
+ values[2] = _mm256_mask_sub_epi8(values[2], mask[2], _mm256_setzero_si256(), values[2]);
+ values[3] = _mm256_mask_sub_epi8(values[3], mask[3], _mm256_setzero_si256(), values[3]);
+#else
+ auto s128 = _mm_loadu_si128((const __m128i *)sign_bits);
+ auto s256 = MM256_SET_M128I(s128, s128);
+ __m256i aux256;
+ auto shuffle = mask1;
+ auto step = _mm256_set1_epi8(4);
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[0] = _mm256_sign_epi8(values[0], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[1] = _mm256_sign_epi8(values[1], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[2] = _mm256_sign_epi8(values[2], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+ aux256 = _mm256_and_si256(_mm256_shuffle_epi8(s256, shuffle), mask2); shuffle = _mm256_add_epi8(shuffle, step);
+ values[3] = _mm256_sign_epi8(values[3], _mm256_or_si256(_mm256_cmpeq_epi8(aux256, mask2), mone));
+#endif
+ }
+ const __m256i mask1 = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const __m256i mask2 = _mm256_set1_epi64x(0x8040201008040201ull);
+ const __m256i mone = _mm256_set1_epi8(1);
+};
+
+struct SimpleBits {
+ __m256i values[4];
+};
+
+#ifdef HAVE_FANCY_SIMD
+//====================================== Zen4 ==================================================
+
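+// In this path the unpacked quants are kept in full 512-bit registers (64 quants per register),
+// i.e. each struct below works on two 256-bit blocks at a time.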
+struct BlockPermuter {
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 9, 8, 3, 2, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);
+};
+
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ auto tmp1 = _mm512_and_si512(q4bits, ml);
+ auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ tmp1 = _mm512_and_si512(q4bits, ml);
+ tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);
+ values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);
+ }
+ inline void prepare64(const uint8_t * q4) {
+ auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);
+ values[0] = _mm512_and_si512(q4bits, ml);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);
+ values[2] = _mm512_and_si512(q4bits, ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0xf);
+ BlockPermuter perm;
+};
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2) {
+
+ auto q2bits = _mm512_loadu_si512((const __m512i*)q2);
+ auto tmp = _mm512_srli_epi16(q2bits, 2);
+
+ values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);
+ values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);
+ values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);
+ values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);
+ values[0] = _mm512_and_si512(values[0], ml);
+ values[2] = _mm512_and_si512(values[2], ml);
+ }
+ __m512i values[4];
+ const __m512i ml = _mm512_set1_epi8(0x03);
+ BlockPermuter perm;
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+};
+
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ prepare(x[i].qs);
+ auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
+ s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
+ auto scales256 = MM256_SET_M128I(scales128, scales128);
+ auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
+ scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+ }
+ static __m512i load_values() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+ auto val256 = MM256_SET_M128I(val128, val128);
+ return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+ }
+ inline void prepare(const uint8_t * q4) {
+ bits.prepare64(q4);
+        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
+        //                bits.values[1]: 16...31, 48...63, 80...95, 112...127
+ // etc.
+ auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
+ bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
+ bits.values[0] = _mm512_shuffle_epi8(values, tmp);
+ tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
+ bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
+ bits.values[2] = _mm512_shuffle_epi8(values, tmp);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+ ScaleIQ4XS siq4;
+ const __m512i values;
+ const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
+ const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
+};
+
+struct HighBit5 {
+ inline void apply(const uint8_t * h, Q4Bits& bits) {
+ auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
+ }
+ const __m512i mh = _mm512_set1_epi8(0x10);
+};
+
+struct HighBit3 {
+ inline void apply(const uint8_t * h, Q2Bits& bits) {
+ auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
+ auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));
+ }
+ const __m512i mh = _mm512_set1_epi8(0x04);
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ hbits.apply(x[i].qh, bits);
+ auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
+ scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
+ }
+
+ Q4Bits bits;
+ HighBit5 hbits;
+ Scales8K s8k;
+};
+
+struct Scale16 {
+ inline void make_scales(const __m128i& scales8, __m512i * scales) const {
+ auto all_scales8 = MM256_SET_M128I(scales8, scales8);
+ auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);
+ auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);
+ scales[0] = _mm512_cvtepi8_epi16(scales1);
+ scales[1] = _mm512_cvtepi8_epi16(scales2);
+ }
+ template <typename Q8>
+ inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,
+ const Q8& q8, __m256 * accm, __m512i * scales) const {
+ process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);
+ make_scales(scales8, scales);
+ }
+ const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,
+ 0x05050505, 0x01010101, 0x04040404, 0x00000000);
+ const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,
+ 0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
+};
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);
+ }
+
+ Q2Bits bits;
+ Scale16 sc16;
+ const __m128i m4 = _mm_set1_epi8(0xf);
+
+};
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare(x[i].qs);
+ hbits.apply(x[i].hmask, bits);
+ auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
+ sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
+ }
+
+ Q2Bits bits;
+ HighBit3 hbits;
+ ScaleQ3 sc3;
+ Scale16 sc16;
+ const __m128i m4 = _mm_set1_epi8(0xf);
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ bits.prepare64(x[i].ql);
+ add_high_bits(x[i].qh, bits);
+ auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
+ sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
+ }
+
+ inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
+ auto hbits = _mm512_loadu_si512((const __m512i *)qh);
+ auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
+ auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
+ bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
+ bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
+ tmp1 = _mm512_and_si512(hbits, mh);
+ tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
+ bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
+ bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
+ }
+
+ Q4Bits bits;
+ HighBit3 hbits;
+ Scale16 sc16;
+
+ const __m512i mh = _mm512_set1_epi8(0x30);
+
+};
+
+template <typename Q8>
+inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
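+// Main AVX512 matrix-multiplication kernel: for each row of the quantized matrix, dequantize one
+// super-block at a time and accumulate its dot product against all nrc_y q8_K rows.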
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accm[nrc_y];
+ __m512 accd[nrc_y];
+ __m512i scales[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
+ for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accm, scales);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ //compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd);
+ const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
+ const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
+ const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
+ const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
+ auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
+ sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
+ accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+ info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+ }
+
+ }
+}
+
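+// Variant for a single q8 row: processes two super-blocks per iteration with two interleaved
+// dequantizers, presumably to hide the dequantization latency behind the dot products.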
+template <typename Dequantizer>
+static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ constexpr int k_nx = 2;
+
+ Q8<1> q8(info);
+
+ Dequantizer deq1(vx, bx);
+ Dequantizer deq2(vx, bx);
+
+ Dequantizer * deq[k_nx];
+ deq[0] = &deq1;
+ deq[1] = &deq2;
+
+ __m512i scales[2*k_nx];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ auto accd = _mm512_setzero_ps();
+ auto accm = _mm256_setzero_ps();
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
+
+ for (int i = 0; i < nb/k_nx; ++i) {
+
+ for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
+
+ for (int kx = 0; kx < k_nx; ++kx) {
+ compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+ }
+
+ }
+ if (2*(nb/2) < nb) {
+ int i0 = 2*(nb/2);
+ deq[0]->new_block(i0, q8, &accm, scales);
+ compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+ }
+
+ auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+ info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+ }
+}
+
+#else
+// ===================================== Vanilla AVX2 =====================================
+
+struct Q4Bits {
+ inline void prepare(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[2] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare64(const uint8_t * q4, int j) {
+ auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
+ values[0] = _mm256_and_si256(q4bits, ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
+ values[1] = _mm256_and_si256(q4bits, ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
+ }
+ inline void prepare16(const uint8_t * q4, int j) {
+ values[0] = dequant16(q4 + 64*j + 0);
+ values[1] = dequant16(q4 + 64*j + 16);
+ values[2] = dequant16(q4 + 64*j + 32);
+ values[3] = dequant16(q4 + 64*j + 48);
+ }
+ inline __m256i dequant16(const uint8_t * qs) const {
+ const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
+ const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);
+ return _mm256_and_si256(ml, aux256);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0xf);
+};
+
+struct Q2Bits {
+ inline void prepare(const uint8_t * q2, int j) {
+ auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);
+ values[0] = _mm256_and_si256(q2bits, ml);
+ values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
+ values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
+ values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
+ }
+ __m256i values[4];
+ const __m256i ml = _mm256_set1_epi8(0x03);
+};
+
+struct HighBit5 {
+ inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
+ inline void apply(Q4Bits& bits, bool do_shift) {
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ if (do_shift) {
+ hbits = _mm256_srli_epi16(hbits, 4);
+ }
+ }
+ const __m256i mh = _mm256_set1_epi8(0x10);
+ __m256i hbits;
+};
+
+struct HighBit3 {
+ inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
+ inline void apply(Q2Bits& bits, bool do_shift) {
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
+ if (do_shift) {
+ hbits = _mm256_srli_epi16(hbits, 4);
+ }
+ }
+ const __m256i mh = _mm256_set1_epi8(0x04);
+ __m256i hbits;
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+};
+
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+ DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
+ s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
+ return MM256_SET_M128I(scales128, scales128);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs, j);
+ bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
+ bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
+ bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
+ bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
+ }
+
+ static __m256i load_values() {
+ static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
+ auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
+ return MM256_SET_M128I(val128, val128);
+ }
+
+ Q4Bits bits;
+ Scales8K s8k;
+ ScaleIQ4XS siq4;
+ const __m256i values;
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ hbits.load(x[i].qh);
+ return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ hbits.apply(bits, j == 0);
+ }
+
+ Q4Bits bits;
+ HighBit5 hbits;
+ Scales8K s8k;
+};
+
+template <typename Q8>
+inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
+ __m256 * accm, __m256i * scales) {
+ const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+ process_mins_16(all_scales, q8, i, d, accm);
+ prepare_scales_16(all_scales, scales);
+}
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ hbits.load(x[i].hmask);
+ process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ hbits.apply(bits, j == 0);
+ }
+
+ Q2Bits bits;
+ HighBit3 hbits;
+ ScaleQ3 sc3;
+
+ const __m128i m32 = _mm_set1_epi8(-32);
+};
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+ const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+ const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+ process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
+ prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs, j);
+ }
+
+ Q2Bits bits;
+
+ const __m128i m4 = _mm_set1_epi8(0xf);
+};
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+ template <typename Q8>
+ inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare64(x[i].ql, j);
+ auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
+ bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
+ bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
+ bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
+ bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
+ }
+
+ Q4Bits bits;
+ const __m256i mh = _mm256_set1_epi8(0x30);
+};
+
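+// AVX2 kernel for quants with 16 scale groups per super-block (q2_K, q3_K, q6_K).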
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK_K == 0);
+ const int nb = n/QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ __m256i all_scales[2];
+ __m256i scales[4];
+ __m256 accd[nrc_y];
+
+ Dequantizer deq(vx, bx);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ for (int i = 0; i < nb; ++i) {
+
+ deq.new_block(i, q8, accd, all_scales);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ set_scales_16(all_scales[j], scales);
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+
+}
+
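+// AVX2 kernel for quants with 8 scale groups per super-block (q4_K, q5_K, iq4_xs).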
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y> q8(info);
+
+ Dequantizer deq(vx, bx);
+
+ __m256 accd[nrc_y];
+ __m256i scales[4];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ auto all_scales = deq.new_block(i, q8, accd);
+
+ __m256i sumi[nrc_y];
+
+ for (int j = 0; j < QK_K/128; ++j) {
+
+ deq.prepare(i, j);
+
+ set_scales_8(all_scales, j, scales);
+
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+ accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+
+ }
+}
+
+#endif // Zen4 or vanilla AVX2
+
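+// multiply_add_1: single q8 row version of multiply_add; the q8 quants are pre-loaded (and, for the
+// i-quants, already sign-adjusted) by the caller and passed in directly instead of being loaded here.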
+template <typename Bits>
+inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
+ if (j == 0) {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
+#else
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ sumi[0] = _mm256_add_epi32(p1, p3);
+ sumi[1] = _mm256_add_epi32(p2, p4);
+#endif
+ } else {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
+ auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
+ auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
+ auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
+ sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
+ sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
+#else
+ const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
+ const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
+ const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
+ const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
+ sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));
+ sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));
+#endif
+ }
+}
+
+inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+ auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
+ : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
+#else
+ set_scales_8(all_scales, j, scales);
+#endif
+}
+
+inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
+#ifdef HAVE_FANCY_SIMD
+ auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
+ scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
+ scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
+#else
+ set_scales_16(all_scales, scales);
+#endif
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_K;
+ Q8<1> q8(info);
+ Dequantizer deq(vx, bx);
+ __m256i scales[2];
+ __m256i q8_quants[4];
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ __m256 accd = _mm256_setzero_ps();
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];
+ deq.new_block(i, all_scales);
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j, q8, q8_quants);
+ if constexpr (Dequantizer::num_blocks == 8) {
+ set_scales_8_iq(j, all_scales[0], scales);
+ } else {
+ set_scales_16_iq(all_scales[j], scales);
+ }
+ multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
+ }
+ accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);
+ }
+
+ info.store(ix, 0, hsum_float_8(accd));
+ }
+}
+
+// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
+// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
+// But TG performance for iq3_xs drops by 35%. That makes no sense: what does the compilation of
+// mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
+// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
+//template <typename Q8, typename Bits>
+//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
+//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// }
+//#else
+// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
+// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
+// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
+// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
+// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
+// }
+//#endif
+//}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_K;
+ Q8<nrc_y> q8(info);
+ Dequantizer deq(vx, bx);
+ __m256i scales[4];
+ __m256 accd[nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+ deq.new_row(ix);
+
+ for (int i = 0; i < nb; ++i) {
+
+ __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
+ //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
+ __m256i mins;
+ float dmin = deq.new_block(i, all_scales, mins);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto bsums = q8.load_bsums(iy, i);
+ auto prod = _mm256_madd_epi16(mins, bsums);
+ accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
+ }
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ deq.prepare(i, j);
+ if constexpr (Dequantizer::num_blocks == 8) {
+ set_scales_8(all_scales[0], j, scales);
+ } else {
+ set_scales_16(all_scales[j], scales);
+ }
+ //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
+ multiply_add(deq.bits, scales, j, i, q8, sumi);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
+ accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, hsum_float_8(accd[iy]));
+ }
+ }
+}
+
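+// Q8_K64: activation rows for the iq1_bn/iq2_bn kernels. Each row starts with 4 floats of scale
+// data, followed by the int8 quants in blocks of QK_IQ1BN.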
+template <int nrc> struct Q8_K64 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_K64(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ const float * dptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 4);
+ }
+ }
+
+ inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
+ inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 4*iy); }
+
+ float d[4*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
+struct DequantizerIQ1BN {
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ static __m256i load_shuffle(int i) {
+ static const uint8_t data[128] = {
+ 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
+ 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
+ 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
+ 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
+ };
+ return _mm256_loadu_si256((const __m256i*)data + i);
+ }
+ const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
+ const __m256i mult[4] = {
+ _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
+ };
+ const __m256i m3 = _mm256_set1_epi16(3);
+#ifdef HAVE_FANCY_SIMD
+ const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
+#endif
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
+ auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
+ auto data = MM256_SET_M128I(data128, data128);
+ auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
+ auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
+ auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
+ auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
+#ifdef HAVE_FANCY_SIMD
+ v1 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val1, bmask, val2), m1_8);
+ v2 = _mm256_sub_epi8(_mm256_permutex2var_epi8(val3, bmask, val4), m1_8);
+#else
+ v1 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216), m1_8);
+ v2 = _mm256_sub_epi8(_mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216), m1_8);
+#endif
+ }
+
+};
+
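+// Matrix multiplication for iq1_bn: the quants decode to {-1, 0, +1}, so a whole row can be
+// accumulated in integer arithmetic and the float scales are applied only once per row at the end.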
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+ Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
+ __m256i accd[nrc_y];
+ __m256i val[4];
+
+#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+
+ if constexpr (nrc_y == 1) {
+ __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2]);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]);
+ acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, deq.m1_8, dot1), deq.m1_8, dot2);
+ acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, deq.m1_8, dot3), deq.m1_8, dot4);
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+ acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
+ acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
+#endif
+ }
+ accd[0] = _mm256_add_epi32(acc1, acc2);
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+ for (int i = 0; i < nb/2; ++i) {
+
+ deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
+ deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1])));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3])));
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
+ accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+#endif
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+#else
+ auto dot = _mm256_madd_epi16(m1_16,
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+#endif
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto vd = q8.scale(iy);
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+ auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+ info.store(ix, iy, hsum_float_4(sumf));
+ }
+
+ }
+}
+
+struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn> {
+ DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
+ auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
+ auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
+ make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
+ }
+ IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
+ val[0] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask2), m1_8);
+ val[1] = _mm256_sub_epi8(_mm256_and_si256(q2_1, mask3), mf_8);
+ }
+ IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
+ auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+ make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
+ }
+ const __m256i m1_8 = _mm256_set1_epi8(1);
+ const __m256i mf_8 = _mm256_set1_epi8(16);
+ const __m256i mask2 = _mm256_set1_epi8(0x03);
+ const __m256i mask3 = _mm256_set1_epi8(0x30);
+};
+
+template <int nrc_y>
+IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+ Q8_K64<nrc_y> q8(info);
+ DequantizeIQ2BN deq(vx, bx);
+ __m256i accd[nrc_y];
+ __m256i val[4];
+
+#if !(defined __AVX512VNNI__ && defined __AVX512VL__)
+ const auto m1_16 = _mm256_set1_epi16(1);
+#endif
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ if constexpr (nrc_y == 1) {
+ __m256i acc[2] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare4(i, val);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+ deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1]));
+ acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+ deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3]));
+#else
+ auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 0), val[0])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 1), val[1])));
+ auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 2), val[2])),
+ _mm256_maddubs_epi16(deq.m1_8, _mm256_sign_epi8(q8.load_quants(0, i, 3), val[3])));
+ acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
+ acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
+#endif
+ }
+ accd[0] = _mm256_add_epi32(acc[0], acc[1]);
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
+
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare4(i, val);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), val[1]);
+ auto dot3 = _mm256_sign_epi8(q8.load_quants(iy, i, 2), val[2]);
+ auto dot4 = _mm256_sign_epi8(q8.load_quants(iy, i, 3), val[3]);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(
+ accd[iy], deq.m1_8, dot1), deq.m1_8, dot2), deq.m1_8, dot3), deq.m1_8, dot4);
+#else
+ auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)),
+ _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot3), _mm256_maddubs_epi16(deq.m1_8, dot4))));
+ accd[iy] = _mm256_add_epi32(dot, accd[iy]);
+#endif
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare2(i, val);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 0), val[0]);
+ auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i/2, 1), val[1]);
+#if defined __AVX512VNNI__ && defined __AVX512VL__
+ accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], deq.m1_8, dot1), deq.m1_8, dot2);
+#else
+ dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(deq.m1_8, dot1), _mm256_maddubs_epi16(deq.m1_8, dot2)));
+ accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
+#endif
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto vd = q8.scale(iy);
+ auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
+ auto sumf = _mm_mul_ps(vd, _mm_cvtepi32_ps(sumi));
+ info.store(ix, iy, hsum_float_4(sumf));
+ }
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ if constexpr (nrc_y == 1) {
+ mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);
+ } else {
+ mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);
+ }
+}
+
+//#ifdef HAVE_FANCY_SIMD
+// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster
+// compared to the vanilla AVX2 version below.
+//struct IndexHelperIQ3S {
+// union index_t {
+// __m256i vec;
+// uint16_t val[16];
+// };
+// inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+// auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs));
+// const __mmask16 * m16 = (const __mmask16 *)qh;
+// index_t idx;
+// idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset);
+// values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]],
+// iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]);
+// values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]],
+// iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]);
+// }
+// const __m256i offset = _mm256_set1_epi16(256);
+//};
+//#else
+struct IndexHelperIQ3S {
+ union index_t {
+ __m256i vec;
+ uint32_t val[8];
+ };
+ inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const {
+ index_t idx;
+ auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
+ auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask);
+ idx.vec = _mm256_or_si256(idx_h, idx_l);
+ values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+ iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+ idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8)));
+ idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask);
+ idx.vec = _mm256_or_si256(idx_h, idx_l);
+ values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]],
+ iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]);
+ }
+ const __m256i idx_mask = _mm256_set1_epi32(256);
+ const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+};
+//#endif
+
+struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
+ DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ inline __m128i make_scales(int i, float& dd) const {
+ dd = GGML_FP16_TO_FP32(x[i].d);
+ uint32_t aux32[2];
+ std::memcpy(aux32, x[i].scales, 4);
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+ aux32[0] &= 0x0f0f0f0f;
+ auto scales8 = _mm_shuffle_epi8(_mm_loadl_epi64((const __m128i *)aux32), _mm_set1_epi64x(0x0703060205010400));
+ auto scales16 = _mm256_castsi256_si128(_mm256_cvtepi8_epi16(scales8));
+ return _mm_or_si128(_mm_slli_epi16(scales16, 1), _mm_set1_epi16(1));
+ }
+ inline void new_block(int i, __m256i * scales) {
+ auto scales16 = make_scales(i, d);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto scales16 = make_scales(i, d);
+ mins = scb.shuffle(scales16);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ return -minv*d;
+ }
+
+ inline void prepare(int i, int j) {
+ prepare_unsigned(i, j);
+ sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, bits.values);
+ for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ prepare_unsigned(i, j);
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
+ }
+
+ inline void prepare_unsigned(int i, int j) {
+ auto qs = x[i].qs + 32*j;
+ auto qh = x[i].qh + 4*j;
+ helper.make2(qs+ 0, qh+0, bits.values+0);
+ helper.make2(qs+16, qh+2, bits.values+2);
+ }
+
+ constexpr static int minv = 16;
+
+ SimpleBits bits;
+ SignHelper sh;
+ Scales8KBase scb;
+ IndexHelperIQ3S helper;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
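+// The iq2_xxs/iq3_xxs sign packing stores 7 sign bits per group of 8 quants; the 8-th sign bit is
+// chosen so that each group has even parity, i.e. it equals the parity of the 7 stored bits.
+// With AVX512 (HAVE_FANCY_SIMD) the missing bit is reconstructed via popcount and the signs are
+// applied with a masked subtract; on plain AVX2 the pre-computed keven_signs table of per-byte
+// +/-1 patterns is used with _mm256_sign_epi8 instead.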
+struct EvenSignHelper {
+#ifdef HAVE_FANCY_SIMD
+ union sbits_t {
+ __m128i vec;
+ __mmask32 mask[4];
+ };
+ IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {
+ aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);
+ auto pcnt = _mm256_popcnt_epi32(aux);
+ sbits_t sbits;
+ sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
+ values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]);
+ //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));
+ //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;
+ //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);
+ //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);
+ }
+ const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);
+ const __m256i mask = _mm256_set1_epi32(127);
+ const __m256i mone = _mm256_set1_epi32(1);
+#else
+ inline void sign_value(uint32_t aux32, __m256i& value) const {
+ auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],
+ keven_signs[(aux32 >> 7) & 127], keven_signs[(aux32 >> 0) & 127]);
+ value = _mm256_sign_epi8(value, signs);
+ }
+#endif
+};
+
+struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
+ DequantizerIQ3XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ inline __m128i prepare_scales(int i) {
+ d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm256_loadu_si256((const __m256i *)(x[i].qs + QK_K/4));
+ auto scales32 = _mm256_srli_epi32(tmp, 28);
+ scales32 = _mm256_or_si256(_mm256_slli_epi32(scales32, 1), _mm256_set1_epi32(1));
+ return _mm_packs_epi32(_mm256_castsi256_si128(scales32), _mm256_extractf128_si256(scales32, 1));
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ auto scales16 = prepare_scales(i);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto scales16 = prepare_scales(i);
+ mins = scb.shuffle(scales16);
+ scales[0] = MM256_SET_M128I(scales16, scales16);
+ return -d*minv;
+ }
+
+ inline static __m256i make_quants(const uint8_t * qs) {
+ return _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
+ iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
+ }
+ inline static void make4_unsigned(const uint8_t * qs, __m256i * values) {
+ values[0] = make_quants(qs+ 0);
+ values[1] = make_quants(qs+ 8);
+ values[2] = make_quants(qs+16);
+ values[3] = make_quants(qs+24);
+ }
+
+ IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const {
+#ifdef HAVE_FANCY_SIMD
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values);
+#else
+ esh.sign_value(signs[0] | (signs[1] << 16), values[0]);
+ esh.sign_value(signs[2] | (signs[3] << 16), values[1]);
+#endif
+ }
+
+ inline void prepare(int i, int j) {
+ auto qs = x[i].qs + 32*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+ make4_unsigned(qs, bits.values);
+ sign_2_values(signs+0, bits.values+0);
+ sign_2_values(signs+4, bits.values+2);
+        for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_add_epi8(bits.values[k], min_value);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ auto qs = x[i].qs + 32*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+ make4_unsigned(qs, bits.values);
+ sign_2_values(signs+0, q8_quants+0);
+ sign_2_values(signs+4, q8_quants+2);
+ }
+
+ constexpr static int minv = 64;
+
+ SimpleBits bits;
+ Scales8KBase scb;
+ EvenSignHelper esh;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
+ DequantizerIQ2S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 16;
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
+ auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ inline static void prepare_scales(const __m256i& all, __m256i * scales) {
+ auto scales_l = _mm256_castsi256_si128(all);
+ auto scales_h = _mm256_extractf128_si256(all, 1);
+ scales[0] = MM256_SET_M128I(scales_l, scales_l);
+ scales[1] = MM256_SET_M128I(scales_h, scales_h);
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ prepare_scales(load_scales(i), scales);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ mins = load_scales(i);
+ prepare_scales(mins, scales);
+ return -d*minv;
+ }
+
+ union index_t {
+ __m256i vec;
+ uint32_t val[8];
+ };
+
+ inline static void make2(const uint8_t * qs, const uint8_t * qh, const __m256i& idx_shift, const __m256i& idx_mask, __m256i * values) {
+ auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs));
+ auto idx_h = MM256_SET_M128I(_mm_set1_epi32(qh[1]), _mm_set1_epi32(qh[0]));
+ index_t idx;
+ idx.vec = _mm256_or_si256(idx_l, _mm256_and_si256(_mm256_sllv_epi32(idx_h, idx_shift), idx_mask));
+ values[0] = _mm256_set_epi64x(iq2s_grid[idx.val[3]], iq2s_grid[idx.val[2]], iq2s_grid[idx.val[1]], iq2s_grid[idx.val[0]]);
+ values[1] = _mm256_set_epi64x(iq2s_grid[idx.val[7]], iq2s_grid[idx.val[6]], iq2s_grid[idx.val[5]], iq2s_grid[idx.val[4]]);
+ }
+ inline static void make2_signed(const SignHelper& sh, const uint8_t * qs, const uint8_t * qh, const uint16_t * sidx,
+ const __m256i& idx_shift, const __m256i& idx_mask, const __m256i& min_value, __m256i * values) {
+ make2(qs, qh, idx_shift, idx_mask, values);
+ values[0] = _mm256_add_epi8(sh.sign_value(sidx+0, values[0]), min_value);
+ values[1] = _mm256_add_epi8(sh.sign_value(sidx+2, values[1]), min_value);
+ }
+
+ inline void prepare(int i, int j) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2_signed(sh, qs+0, qh+0, signs+0, idx_shift, idx_mask, min_value, bits.values+0);
+ make2_signed(sh, qs+8, qh+2, signs+4, idx_shift, idx_mask, min_value, bits.values+2);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ auto qs = x[i].qs + 16*j;
+ auto qh = x[i].qh + 4*j;
+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+ make2(qs+0, qh+0, idx_shift, idx_mask, bits.values+0);
+ make2(qs+8, qh+2, idx_shift, idx_mask, bits.values+2);
+ q8_quants[0] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+0), sh.make_signs(signs[0] | (signs[1] << 16)));
+ q8_quants[1] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+1), sh.make_signs(signs[2] | (signs[3] << 16)));
+ q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
+ q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
+ }
+
+ constexpr static int minv = 43;
+
+ SimpleBits bits;
+ SignHelper sh;
+ const __m256i idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
+ const __m256i idx_mask = _mm256_set1_epi32(0x300);
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
+ DequantizerIQ2XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 16;
+
+ inline __m256i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+ auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
+ auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
+ return _mm256_cvtepi8_epi16(scales8);
+ }
+ inline static void prepare_scales(const __m256i& all, __m256i * scales) {
+ auto scales_l = _mm256_castsi256_si128(all);
+ auto scales_h = _mm256_extractf128_si256(all, 1);
+ scales[0] = MM256_SET_M128I(scales_l, scales_l);
+ scales[1] = MM256_SET_M128I(scales_h, scales_h);
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ prepare_scales(load_scales(i), scales);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ mins = load_scales(i);
+ prepare_scales(mins, scales);
+ return -d*minv;
+ }
+
+ struct Helper {
+ const __m256i mone = _mm256_set1_epi8(1);
+ const __m256i mask = _mm256_set1_epi64x(0x8040201008040201);
+ //const __m256i bhelper = _mm256_set_epi64x(0x8000008000808000, 0x0080800080000080, 0x8000008000808000, 0x0080800080000080);
+ const __m256i bhelper = load_bhelper();
+ const __m256i shuff1 = _mm256_set_epi64x(0x0606060606060606, 0x0404040404040404, 0x0202020202020202, 0x0000000000000000);
+ const __m256i shuff2 = _mm256_set_epi64x(0x0e0e0e0e0e0e0e0e, 0x0c0c0c0c0c0c0c0c, 0x0a0a0a0a0a0a0a0a, 0x0808080808080808);
+ static __m256i load_bhelper() {
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ return _mm256_loadu_si256((const __m256i*)k_bit_helper);
+ }
+ };
+
+ union index_t {
+ __m256i vec;
+        uint16_t val[16];
+ };
+
+ inline static void make4(const __m256i& data, const __m256i& mask, __m256i * values) {
+ index_t idx;
+ idx.vec = _mm256_and_si256(data, mask);
+ values[0] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 3]], iq2xs_grid[idx.val[ 2]], iq2xs_grid[idx.val[ 1]], iq2xs_grid[idx.val[ 0]]);
+ values[1] = _mm256_set_epi64x(iq2xs_grid[idx.val[ 7]], iq2xs_grid[idx.val[ 6]], iq2xs_grid[idx.val[ 5]], iq2xs_grid[idx.val[ 4]]);
+ values[2] = _mm256_set_epi64x(iq2xs_grid[idx.val[11]], iq2xs_grid[idx.val[10]], iq2xs_grid[idx.val[ 9]], iq2xs_grid[idx.val[ 8]]);
+ values[3] = _mm256_set_epi64x(iq2xs_grid[idx.val[15]], iq2xs_grid[idx.val[14]], iq2xs_grid[idx.val[13]], iq2xs_grid[idx.val[12]]);
+ }
+ inline static void sign_value(const __m256i& sign_bits, const __m256i& shuffle, const __m256i& mask,
+ const __m256i& mone, __m256i& value) {
+ auto signs = _mm256_shuffle_epi8(sign_bits, shuffle);
+ signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, mask), mask);
+ value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone));
+ }
+ inline void sign_values(const __m256i& data, __m256i * values) const {
+#ifdef HAVE_FANCY_SIMD
+ auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9));
+ auto pcnt = _mm_popcnt_epi8(partial_bits);
+ auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7));
+ const __mmask32 * m32 = (const __mmask32 *)&full_bits;
+ auto zero = _mm256_setzero_si256();
+ values[0] = _mm256_mask_sub_epi8(values[0], m32[0], zero, values[0]);
+ values[1] = _mm256_mask_sub_epi8(values[1], m32[1], zero, values[1]);
+ values[2] = _mm256_mask_sub_epi8(values[2], m32[2], zero, values[2]);
+ values[3] = _mm256_mask_sub_epi8(values[3], m32[3], zero, values[3]);
+#else
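+        // iq2_xs also stores 7 sign bits per group of 8 with the 8-th being the even-parity bit.
+        // XOR-ing bits 4..6 onto bits 0..2 (psb1 ^ psb2) reduces the 7-bit sign value to 4 bits
+        // with the same parity; bhelper is a 16-entry parity LUT, so the shuffle yields 0x80 for
+        // odd parity, which OR-ed into psb1 reconstructs the implied 8-th sign bit.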
+ auto psb1 = _mm256_srli_epi16(data, 9);
+ auto psb2 = _mm256_srli_epi16(data, 13);
+ auto psbc = _mm256_xor_si256(psb1, psb2);
+ auto oddb = _mm256_shuffle_epi8(helper.bhelper, psbc);
+ auto full = _mm256_or_si256(psb1, oddb);
+ auto full_l = _mm256_castsi256_si128(full);
+ auto full_h = _mm256_extractf128_si256(full, 1);
+ auto full_1 = MM256_SET_M128I(full_l, full_l);
+ auto full_2 = MM256_SET_M128I(full_h, full_h);
+ sign_value(full_1, helper.shuff1, helper.mask, helper.mone, values[0]);
+ sign_value(full_1, helper.shuff2, helper.mask, helper.mone, values[1]);
+ sign_value(full_2, helper.shuff1, helper.mask, helper.mone, values[2]);
+ sign_value(full_2, helper.shuff2, helper.mask, helper.mone, values[3]);
+#endif
+ }
+ inline void make4_signed(const uint16_t * qs, const __m256i& m511,
+ const __m256i& min_value, __m256i * values) const {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, values);
+ for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
+ }
+ inline void make4(const uint16_t * qs, const __m256i& m511, __m256i * values, __m256i * q8) const {
+ auto q2 = _mm256_loadu_si256((const __m256i *)qs);
+ make4(q2, m511, values);
+ sign_values(q2, q8);
+ }
+
+ inline void prepare(int i, int j) {
+ make4_signed(x[i].qs + 16*j, idx_mask, min_value, bits.values);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ make4(x[i].qs + 16*j, idx_mask, bits.values, q8_quants);
+ }
+
+ constexpr static int minv = 43;
+
+ SimpleBits bits;
+#ifndef HAVE_FANCY_SIMD
+ Helper helper;
+#endif
+ const __m256i idx_mask = _mm256_set1_epi16(511);
+ const __m256i min_value = _mm256_set1_epi8(minv);
+
+};
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+ DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
+
+ constexpr static int num_blocks = 8;
+
+ union Data {
+ __m256i vec;
+ uint32_t val[8];
+ };
+
+ inline __m128i load_scales(int i) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ const uint16_t * a16 = (const uint16_t *)x[i].qs;
+ auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);
+ return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));
+ }
+
+ inline void new_block(int i, __m256i * scales) {
+ auto sc16 = load_scales(i);
+ scales[0] = MM256_SET_M128I(sc16, sc16);
+ }
+ inline float new_block(int i, __m256i * scales, __m256i& mins) {
+ auto sc16 = load_scales(i);
+ mins = scb.shuffle(sc16);
+ scales[0] = MM256_SET_M128I(sc16, sc16);
+ return -d*minv;
+ }
+
+ inline static void make4(const uint32_t * aux32, __m256i * values) {
+ const uint8_t * aux8 = (const uint8_t *)aux32;
+ values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);
+ values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);
+ values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);
+ values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);
+ }
+
+ IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
+#ifdef HAVE_FANCY_SIMD
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
+ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
+#else
+ esh.sign_value(aux32[1], values[0]);
+ esh.sign_value(aux32[3], values[1]);
+ esh.sign_value(aux32[5], values[2]);
+ esh.sign_value(aux32[7], values[3]);
+#endif
+ }
+ inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {
+ make4(aux32, values);
+ sign_values(aux32, values);
+ for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);
+ }
+ inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {
+ make4(aux32, values);
+ sign_values(aux32, q8);
+ }
+ inline void prepare(int i, int j) {
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ make4_signed(data.val, min_value, bits.values);
+ }
+ inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
+ for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
+ Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+ make4(data.val, bits.values, q8_quants);
+ }
+
+ constexpr static int minv = 43;
+ SimpleBits bits;
+ Scales8KBase scb;
+ EvenSignHelper esh;
+ const __m256i min_value = _mm256_set1_epi8(minv);
+ const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);
+};
+
+//
+// ============================== Legacy quants
+//
+
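+// DotHelper computes per-32-byte dot products of the x quants with the q8 quants. Both
+// _mm256_dpbusd_epi32 and _mm256_maddubs_epi16 require the first operand to be unsigned, which is
+// why SignedDot below multiplies |x| with sign(x)*y instead of x with y.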
+struct DotHelper {
+ const __m256i m1 = _mm256_set1_epi16(1);
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+ inline __m256i dot(__m256i x, __m256i y) const {
+ return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);
+ }
+#else
+ inline __m256i dot(__m256i x, __m256i y) const {
+ return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));
+ }
+#endif
+};
+
+struct SignedDot {
+ DotHelper helper;
+ inline __m256i compute(__m256i x, __m256i y) const {
+ return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));
+ }
+};
+struct UnsignedDot {
+ DotHelper helper;
+ inline __m256i compute(__m256i x, __m256i y) const {
+ return helper.dot(x, y);
+ }
+};
+
+template <typename Q8, typename Q8x4, typename Dot, bool can_pack = true> struct Sum4 {
+ Dot dot;
+ inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+ const Q8x4 * y4 = (const Q8x4 *)y;
+ const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0
+ const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1
+ const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2
+ const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3
+ if constexpr (can_pack) {
+ const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1
+ const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3
+ return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3
+ } else {
+ // Note to myself: this is much faster than using _mm256_hadd_epi32()
+ auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1
+ auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3
+ return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
+ }
+ }
+};
+// If this version is used instead of the packing variant above, it negatively impacts q4_1/q5_1 performance.
+//template <typename Q8, typename Q8x4, typename Dot> struct Sum4 {
+// Dot dot;
+// inline __m256i compute(const __m256i * qx, const Q8 * y) const {
+// const Q8x4 * y4 = (const Q8x4 *)y;
+// const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0
+// const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1
+// const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2
+// const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3
+// auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1
+// auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3
+// return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3
+// }
+//};
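+
+// A note on the can_pack template parameter: _mm256_packs_epi32 saturates to int16, so the packed
+// reduction is only safe when the per-block partial sums are guaranteed to fit into int16 (as they
+// are for the q4_1/q5_1 path). For q8_0 and iq4_nl quants the sums can exceed 32767, hence
+// Sum4TypeQ80 below uses the unpack/add reduction (can_pack = false).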
+
+struct ScaleHelperQ8_0 {
+ inline __m128 prepare4(const block_q8_0 * y) {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
+ return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d));
+ }
+ inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) {
+ return _mm_mul_ps(other_scales, prepare4(y));
+ }
+ template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
+ template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
+};
+
+struct ScaleHelperQ_0 {
+ ggml_half scales8[4];
+ template <typename Q>
+ inline __m128 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;
+ return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));
+ }
+ template <typename Q>
+ inline __m128 prepare4(__m128 other_scales, const Q * y) {
+ return _mm_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }
+ template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }
+};
+
+struct ScaleHelperQ8_1 {
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y;
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d));
+ }
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+};
+
+struct ScaleHelperQ_1 {
+ uint32_t scales8[4];
+ const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);
+
+ template <typename Q>
+ inline __m256 prepare4(const Q * y) {
+ for (int j = 0; j < 4; ++j) {
+            // it is slightly faster to directly dereference (const uint32_t *)&y[j].d, but some compilers
+ // complain that this breaks strict-aliasing rules.
+ memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));
+ }
+ return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));
+ }
+
+ template <typename Q>
+ inline __m256 prepare4(__m256 other_scales, const Q * y) {
+ return _mm256_mul_ps(other_scales, prepare4<Q>(y));
+ }
+
+ template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
+ return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));
+ }
+ template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));
+ }
+ std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {
+ return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));
+ }
+};
+
+struct MinusType0 {
+ inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
+ inline float compute(float d, int) const { return d; }
+ inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
+};
+
+template <int nrc_y> struct MinusType1 {
+ __m128 accm[nrc_y];
+ MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }
+ inline __m256 compute(__m256 dm, int iy) {
+ const __m128 d = _mm256_castps256_ps128(dm);
+ const __m128 m = _mm256_extractf128_ps(dm, 1);
+ accm[iy] = _mm_add_ps(accm[iy], m);
+ return _mm256_set_m128(d, d);
+ }
+ inline float compute(const std::pair<float, float>& dm, int iy) {
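+        // dm.second is a single scalar, but it is splatted into all 4 lanes of accm and later summed
+        // with hsum_float_4 in result(), so it is pre-multiplied by 0.25f to compensate.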
+ accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));
+ return dm.first;
+ }
+ inline float result(__m256 acc, int iy) const {
+ const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+ return hsum_float_4(_mm_add_ps(sum, accm[iy]));
+ }
+};
+
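+// AccumT drives the accumulation over a row: blocks are processed in groups of 4 (via the Sum4
+// helpers and the *_x4 q8 layout), with a simple per-block tail when the number of blocks is not
+// a multiple of 4 (is_multiple_of_4 = false).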
+template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
+ __m256 acc[nrc_y];
+ Minus accm;
+ AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }
+ template <typename Unpacker, typename Scales, typename Sum, typename Q8>
+ inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {
+ auto qx = unp.quants();
+ __m256 dall[nrc_y];
+ for (int i = 0; i < nb/4; ++i) {
+ auto other_scales = unp.set_block_4(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);
+ dall[iy] = accm.compute(s12, iy);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto pall = sum.compute(qx, y[iy] + 4*i);
+ acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);
+ }
+ }
+ if (!is_multiple_of_4) {
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto other_scales = unp.set_block(i);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto s12 = scales.prepare1(other_scales, y[iy] + i);
+ auto d = accm.compute(s12, iy);
+ const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));
+ acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, accm.result(acc[iy], iy));
+ //s[iy*bs] = accm.result(acc[iy], iy);
+ }
+ }
+};
+
+template <int nrc_y, bool is_multiple_of_4>
+using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
+
+template <int nrc_y, bool is_multiple_of_4>
+using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
+
+using Sum4Type0 = Sum4<block_q8_0, block_q8_0_x4, SignedDot>;
+using Sum4Type1 = Sum4<block_q8_1, block_q8_1_x4, UnsignedDot>;
+using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false>;
+
+template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
+void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
+ Unpacker unp(vx, bx);
+ typename Unpacker::Sum4T sum4;
+ Scales scales;
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ unp.set_row(ix);
+ AccumType accum;
+ accum.compute(nb, unp, scales, sum4, y, info, ix);
+ }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%Unpacker::block_size() == 0);
+ Q8<nrc_y, block_q8_0> q8(info);
+ int nb = n/Unpacker::block_size();
+ if (nb%4 == 0) {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ } else {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false>, ScaleHelperQ8_0, block_q8_0, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ }
+}
+
+template <typename Unpacker, int nrc_y>
+void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%Unpacker::block_size() == 0);
+ Q8<nrc_y, block_q8_1> q8(info);
+ int nb = n/Unpacker::block_size();
+ if (nb%4 == 0) {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, true>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ } else {
+ mul_mat_qX_q8_Helper<Unpacker, AccumType1<nrc_y, false>, ScaleHelperQ8_1, block_q8_1, nrc_y>(
+ nb, vx, bx, info, q8.y, nrc_x
+ );
+ }
+}
+
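+// Dequantizer4bit unpacks 32 4-bit quants from 16 bytes: the low nibbles end up in the low 128-bit
+// half of the result and the high nibbles in the upper half.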
+struct Dequantizer4bit {
+ const __m256i m4 = _mm256_set1_epi8(0xf);
+ inline __m256i dequant(const uint8_t * qs) const {
+ const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);
+ return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);
+ }
+};
+
+struct Q8_0_Dequantizer {
+ inline __m256i dequant(const block_q8_0 * x) const {
+ return _mm256_loadu_si256((const __m256i *)x->qs);
+ }
+};
+
+struct Q4_0_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i m8 = _mm256_set1_epi8(-8);
+ inline __m256i dequant(const block_q4_0 * x) const {
+ return _mm256_add_epi8(b4.dequant(x->qs), m8);
+ }
+};
+
+struct IQ4_NL_Dequantizer {
+ Dequantizer4bit b4;
+ const __m256i values = load_values();
+ inline __m256i dequant(const block_iq4_nl * x) const {
+ return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
+ }
+ static __m256i load_values() {
+ static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+ auto aux = _mm_loadu_si128((const __m128i *)iq4nl_values);
+ return MM256_SET_M128I(aux, aux);
+ }
+};
+
+struct Q4_1_Dequantizer {
+ Dequantizer4bit b4;
+ inline __m256i dequant(const block_q4_1 * x) const {
+ return b4.dequant(x->qs);
+ }
+};
+
+struct HBitDequantizer {
+ const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
+ const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ const __m256i minus1 = _mm256_set1_epi64x(-1);
+ inline __m256i to_bytes(const uint8_t * bits) const {
+ // Note: Data in all ggml quants is at least 2-byte aligned.
+        // => we can cast to uint16_t and OR two consecutive entries together,
+        //    which is faster than memcpy
+ const uint16_t * aux16 = (const uint16_t *)bits;
+ const uint32_t aux32 = aux16[0] | (aux16[1] << 16);
+ //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);
+ bytes = _mm256_or_si256(bytes, mask);
+ return _mm256_cmpeq_epi8(bytes, minus1);
+ }
+};
+
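+// Q5_0: the 5-th bit is merged in via andnot with 0xF0 - bytes whose high bit is zero get 0xF0
+// OR-ed in, which for a 4-bit value is the same as subtracting 16, so the result is the usual
+// q5_0 value in [-16, 15] without a separate subtraction.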
+struct Q5_0_Dequantizer {
+ Dequantizer4bit b4;
+ HBitDequantizer hbit;
+ const __m256i mh = _mm256_set1_epi8((char)0xF0);
+ inline __m256i dequant(const block_q5_0 * x) const {
+ const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);
+ return _mm256_or_si256(b4.dequant(x->qs), vqh);
+ }
+};
+
+struct Q5_1_Dequantizer {
+ Dequantizer4bit b4;
+ HBitDequantizer hbit;
+ const __m256i mh = _mm256_set1_epi8(0x10);
+ inline __m256i dequant(const block_q5_1 * x) const {
+ const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);
+ return _mm256_or_si256(b4.dequant(x->qs), vqh);
+ }
+};
+
+template <typename Q, typename Scales, typename Dequantizer>
+struct Q_Unpacker {
+ Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}
+
+ const char * cx_0;
+ const Q * x;
+ size_t bx;
+
+ Scales scales;
+ Dequantizer deq;
+
+ __m256i qx[4];
+
+ inline const __m256i* quants() const { return qx; }
+
+ inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }
+
+ inline auto set_block_4(int i) {
+ for (int j = 0; j < 4; ++j) {
+ qx[j] = deq.dequant(x + 4*i + j);
+ }
+ return scales.prepare4(x + 4*i);
+ }
+ inline auto set_block(int i) {
+ qx[0] = deq.dequant(x + i);
+ return scales.prepare1(x + i);
+ }
+};
+
+struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
+ Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK8_0; }
+};
+struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
+ Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK4_0; }
+};
+struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
+ IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK4_NL; }
+};
+struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
+ Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4TypeQ80;
+ inline static int block_size() { return QK5_0; }
+};
+struct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {
+ Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4Type1;
+ inline static int block_size() { return QK4_1; }
+};
+struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer> {
+ Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
+ using Sum4T = Sum4Type1;
+    inline static int block_size() { return QK5_1; }
+};
+
+// float matrices - we handle f16 and f32, but always produce an f32 result
+
+struct QFBase {
+#ifdef __AVX512F__
+ constexpr static int k_step = 16;
+ using Data = __m512;
+ using Acc = __m512;
+ static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
+ static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return _mm512_fmadd_ps(y, x, prev);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm512_mul_ps(y, x);
+ }
+ static inline float hsum(Acc acc) {
+ return _mm512_reduce_add_ps(acc);
+ }
+ template <typename Float>
+ static inline Data load4Floats(const Float * x) {
+ return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
+ }
+#else
+ constexpr static int k_step = 8;
+ using Data = __m256;
+ using Acc = __m256;
+ static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
+ static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return _mm256_fmadd_ps(y, x, prev);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return _mm256_mul_ps(y, x);
+ }
+ static inline float hsum(Acc acc) {
+ return hsum_float_8(acc);
+ }
+ template <typename Float>
+ static inline Data load4Floats(const Float * x) {
+ return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
+ }
+#endif
+ static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
+ static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
+};
+template <typename Float, int nrc_in> struct QFT final : public QFBase {
+ constexpr static int nrc = nrc_in;
+ QFT(const DataInfo& info) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);
+ }
+ QFT(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
+ const Float * y[nrc];
+};
+
+template <typename Qy, typename Qx>
+IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+ assert(n%QFBase::k_step == 0);
+ int nb = n/QFBase::k_step;
+ int nb4 = n/4;
+ Qy y(info);
+ Qx x(cx + ix0*bx, bx);
+ QFBase::Data xv[Qx::nrc];
+ QFBase::Acc acc[Qx::nrc*Qy::nrc];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < Qy::nrc; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+ }
+ }
+ for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
+// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
+// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
+// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
+template <int nrc_y, typename FloatX, typename FloatY>
+void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QFBase::k_step == 0);
+#ifdef __AVX512F__
+ constexpr int k_nx = 5;
+#else
+ constexpr int k_nx = 2;
+#endif
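+    // k_nx is the number of x rows processed together: 5 with AVX512 and 2 with AVX2, presumably
+    // chosen to keep the nrc_y*k_nx accumulators plus the loaded x values within the 32 (resp. 16)
+    // available vector registers.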
+ const char * cx = (const char *)vx;
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
+#ifdef __AVX512F__
+ case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
+#endif
+ }
+}
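+
+// For orientation, a scalar sketch of what mul_mat_fX_fY_T computes (not compiled, illustration
+// only; to_f32 is just a local helper for the sketch):
+//static inline float to_f32(float x)     { return x; }
+//static inline float to_f32(ggml_half x) { return GGML_FP16_TO_FP32(x); }
+//template <int nrc_y, typename FloatX, typename FloatY>
+//void mul_mat_fX_fY_ref(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+//    const char * cx = (const char *)vx;
+//    for (int ix = 0; ix < nrc_x; ++ix) {
+//        const FloatX * x = (const FloatX *)(cx + ix*bx);
+//        for (int iy = 0; iy < nrc_y; ++iy) {
+//            const FloatY * y = (const FloatY *)info.src1_row(iy);
+//            float sum = 0;
+//            for (int k = 0; k < n; ++k) sum += to_f32(x[k]) * to_f32(y[k]);
+//            info.store(ix, iy, sum);
+//        }
+//    }
+//}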
+
+//
+// Tiled Q8_0 x Q8_0 implementation. Not used, as the templated legacy quant implementation
+// above is faster. Left behind so we remember we tried.
+//
+template <int nrc> struct Q80 {
+ constexpr static int nrc_y = nrc;
+ Q80(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
+ }
+ IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
+ IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
+
+ const block_q8_0 * y[nrc_y];
+};
+inline __m256i mul_q80(__m256i x, __m256i y) {
+ auto ux = _mm256_sign_epi8(x, x);
+#ifdef HAVE_FANCY_SIMD
+ return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
+#else
+ return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
+#endif
+}
+template <int nrc_y>
+void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n%QK8_0 == 0);
+ constexpr int k_nx = 4;
+ int nb = n/QK8_0;
+ Q80<nrc_y> q8(info);
+ const block_q8_0 * x[k_nx];
+ float ds[k_nx];
+ __m256 acc[k_nx*nrc_y];
+ __m256i xv[k_nx];
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ int ix0 = k_nx*ix;
+ for (int kx = 0; kx < k_nx; ++kx) {
+ x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
+ ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
+ xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto yv = q8.load1(iy, 0);
+ float d = q8.scale(iy, 0);
+ for (int kx = 0; kx < k_nx; ++kx) {
+ auto dot = mul_q80(yv, xv[kx]);
+ acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
+ }
+ }
+ for (int i = 1; i < nb; ++i) {
+ for (int kx = 0; kx < k_nx; ++kx) {
+ ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
+ xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto yv = q8.load1(iy, i);
+ float d = q8.scale(iy, i);
+ for (int kx = 0; kx < k_nx; ++kx) {
+ auto dot = mul_q80(yv, xv[kx]);
+ acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
+ }
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ // TODO: handle remaining rows
+}
+
+template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
+ if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
+ std::is_same_v<Dequantizer, Q8_0_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
+ m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
+ }
+ else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
+ m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_1_q8_1_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_1_q8_1_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_1_q8_1_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;
+ }
+ else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ3S> || std::is_same_v<Dequantizer, DequantizerIQ3XXS> ||
+ std::is_same_v<Dequantizer, DequantizerIQ2S> || std::is_same_v<Dequantizer, DequantizerIQ2XS> ||
+ std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {
+ m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;
+ }
+ else {
+#ifdef HAVE_FANCY_SIMD
+ m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
+ m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;
+#else
+ if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
+ std::is_same_v<Dequantizer, DequantizerQ3K> ||
+ std::is_same_v<Dequantizer, DequantizerQ6K>) {
+ m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;
+ } else {
+ m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;
+ }
+#endif
+ }
+}
+
+template <typename FloatX, typename FloatY>
+void set_mul_mat_f(MulMat& mm) {
+ for (auto& f : mm.funcs) f = nullptr;
+ mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
+ mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
+ mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
+ mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
+ mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
+#ifndef __AVX512F__
+ mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
+#endif
+}
+
+bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
+
+ (void)Ny;
+
+ if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
+ if (ne00 % 4) return false;
+ }
+ if (typeA == GGML_TYPE_F16) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(mm); break;
+ case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(mm); break;
+ default: return false;
+ }
+ return true;
+ }
+ if (typeA == GGML_TYPE_F32) {
+ switch (typeB) {
+ case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(mm); break;
+ case GGML_TYPE_F32: set_mul_mat_f<float, float>(mm); break;
+ default: return false;
+ }
+ return true;
+ }
+
+ auto expected_typeB = GGML_TYPE_Q8_K;
+
+ switch (typeA) {
+ case GGML_TYPE_Q2_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerQ2K>(mm);
+ break;
+ case GGML_TYPE_Q3_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerQ3K>(mm);
+ break;
+ case GGML_TYPE_Q4_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerQ4K>(mm);
+ break;
+ case GGML_TYPE_Q5_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerQ5K>(mm);
+ break;
+ case GGML_TYPE_Q6_K:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerQ6K>(mm);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ4XS>(mm);
+ break;
+ case GGML_TYPE_IQ3_S:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ3S>(mm);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ3XXS>(mm);
+ break;
+ case GGML_TYPE_IQ2_S:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ2S>(mm);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ2XS>(mm);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ assert (ne00 % QK_K == 0);
+ MulMat::set_functions<DequantizerIQ2XXS>(mm);
+ break;
+ case GGML_TYPE_IQ1_BN:
+ assert (ne00 % QK_IQ1BN == 0);
+ mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
+ mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
+ mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
+ mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
+ mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
+ mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
+ mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
+ mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+ expected_typeB = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN:
+ assert (ne00 % QK_IQ1BN == 0);
+ mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
+ mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
+ mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
+ mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
+ mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
+ mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
+ mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
+ mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
+ expected_typeB = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_Q4_0:
+ assert (ne00 % QK4_0 == 0);
+ MulMat::set_functions<Q4_0_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_Q4_1:
+ assert (ne00 % QK4_1 == 0);
+ MulMat::set_functions<Q4_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
+ case GGML_TYPE_Q5_0:
+ assert (ne00 % QK5_0 == 0);
+ MulMat::set_functions<Q5_0_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_Q5_1:
+ assert (ne00 % QK5_1 == 0);
+ MulMat::set_functions<Q5_1_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_1;
+ break;
+ case GGML_TYPE_Q8_0:
+ assert (ne00 % QK8_0 == 0);
+ MulMat::set_functions<Q8_0_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_IQ4_NL:
+ assert (ne00 % QK4_NL == 0);
+ MulMat::set_functions<IQ4_NL_Unpacker>(mm);
+ expected_typeB = GGML_TYPE_Q8_0;
+ break;
+
+ default:
+ return false;
+ }
+
+ return ggml_type(typeB) == expected_typeB;
+}
+
+} // namespace
+
+
+#else // __aarch64__
+
+namespace {
+
+template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
+ }
+
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
+ inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
+ inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
+ inline int16x8_t load_bsums8(int iy, int i) const {
+ auto q8s = vld1q_s16_x2(y[iy][i].bsums);
+ return vpaddq_s16(q8s.val[0], q8s.val[1]);
+ }
+ inline float scale(int iy, int i) const { return y[iy][i].d; }
+
+ const block_q8 * y[nrc_y];
+};
+
+template <typename Q8>
+inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
+ auto p12 = vpaddq_s32(p1, p2);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+                             vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 3
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+                             vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 4
+ auto p34 = vpaddq_s32(p3, p4);
+
+ auto pall = vpaddq_s32(p12, p34);
+ sumi = vmlaq_s32(sumi, scales.val[j], pall);
+}
+
+template <typename Q8>
+inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
+ const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
+
+ auto mzero = vdupq_n_s32(0);
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
+ ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
+                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,
+ auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
+ sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
+                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // blocks 4, 4, 5, 5,
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
+                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // blocks 6, 6, 7, 7,
+ auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
+ sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
+}
+
+template <typename Q8>
+inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums8(iy, i);
+ int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
+ int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
+ float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
+ acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
+ }
+}
+template <typename Q8>
+inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8s = q8.load_bsums(iy, i);
+ int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
+ int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
+ int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
+ int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
+ float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
+ acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
+ }
+}
+
+struct Scales8 {
+ uint32_t utmp[4];
+ const uint8_t * sc8 = (const uint8_t *)utmp;
+ template <typename Q8, typename Qx>
+ inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
+ make_q4_scales(x.scales, utmp);
+ int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
+ accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
+
+ uint8x8_t scales8 = vld1_u8(sc8);
+ uint16x8_t scales16 = vmovl_u8(scales8);
+ int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
+ vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
+ return scales;
+ }
+};
+
+struct Q4bits {
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ uint8x16x4_t b1, b2;
+ inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[2] = vshrq_n_u8(val[0], 4);
+ b.val[1] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
+ b.val[0] = vandq_u8(val[0], m4b);
+ b.val[1] = vshrq_n_u8(val[0], 4);
+ b.val[2] = vandq_u8(val[1], m4b);
+ b.val[3] = vshrq_n_u8(val[1], 4);
+ }
+ inline void prepare(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4(b2, q4bits.val);
+ }
+ inline void prepare_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4(b1, q4bits.val+0);
+ prepare4(b2, q4bits.val+2);
+ }
+ inline void prepare64(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ b1.val[0] = vandq_u8(q4bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q4bits.val[1], m4b);
+ b1.val[2] = vandq_u8(q4bits.val[2], m4b);
+ b1.val[3] = vandq_u8(q4bits.val[3], m4b);
+ b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
+ b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
+ b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
+ b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
+ }
+ inline void prepare16(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x2(qs);
+ prepare4_16(b1, q4bits.val);
+ q4bits = vld1q_u8_x2(qs+32);
+ prepare4_16(b2, q4bits.val);
+ }
+ inline void prepare16_v2(const uint8_t * qs) {
+ auto q4bits = vld1q_u8_x4(qs);
+ prepare4_16(b1, q4bits.val+0);
+ prepare4_16(b2, q4bits.val+2);
+ }
+};
+
+struct Q2bits {
+ const uint8x16_t m4b = vdupq_n_u8(0x03);
+ uint8x16x4_t b1, b2;
+ inline void prepare(const uint8_t * qs) {
+ auto q2bits = vld1q_u8_x2(qs);
+ b1.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b1.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b1.val[3] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[0] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[1] = vandq_u8(q2bits.val[1], m4b);
+
+ q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
+ q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
+ b2.val[2] = vandq_u8(q2bits.val[0], m4b);
+ b2.val[3] = vandq_u8(q2bits.val[1], m4b);
+ }
+};
+
+template <typename block_q>
+struct BaseDequantizer {
+ BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
+ inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }
+ const void * vx;
+ const block_q * x;
+ const size_t bx;
+ const int nrc;
+};
+
+struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
+ DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return s8.process_scales_mins(x[i], q8, i, acc);
+ }
+ inline void prepare(int i, int j) {
+ if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
+ else bits.prepare(x[i].qs+64*j);
+ }
+
+ Q4bits bits;
+ Scales8 s8;
+
+ float d;
+};
+
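+// HighBit5/HighBit3 merge the separately stored high bits into the 4-/2-bit values produced by
+// Q4bits/Q2bits. The two 16-byte qh registers cover a whole super-block; the first call uses the
+// low nibble of each qh byte and, with do_shift (j == 0), shifts the registers down by 4 so the
+// second call uses the high nibbles.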
+struct HighBit5 {
+ const uint8x16_t mhb = vdupq_n_u8(0x10);
+ uint8x16x2_t bits;
+ inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
+ b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
+ b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
+ b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
+ b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
+
+ b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
+ b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
+ b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
+ b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
+
+ if (do_shift) {
+ bits.val[0] = vshrq_n_u8(bits.val[0], 4);
+ bits.val[1] = vshrq_n_u8(bits.val[1], 4);
+ }
+ }
+};
+
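+// Same idea for q3_K: ORs the 3rd bit (value 4) from the hmask bit planes into
+// the 2-bit values produced by Q2bits.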
+struct HighBit3 {
+ const uint8x16_t mhb = vdupq_n_u8(0x04);
+ uint8x16x2_t bits;
+ inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
+ b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
+ b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
+ b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
+ b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
+
+ b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
+ b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
+ b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
+ b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
+
+ if (do_shift) {
+ bits.val[0] = vshrq_n_u8(bits.val[0], 4);
+ bits.val[1] = vshrq_n_u8(bits.val[1], 4);
+ }
+ }
+};
+
+struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
+ DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ h.bits = vld1q_u8_x2(x[i].qh);
+ return s8.process_scales_mins(x[i], q8, i, acc);
+ }
+ inline void prepare(int i, int j) {
+ if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
+ else bits.prepare(x[i].qs+64*j);
+ h.apply(bits.b1, bits.b2, j == 0);
+ }
+
+ Q4bits bits;
+ HighBit5 h;
+ Scales8 s8;
+
+ uint8x16x2_t hbits;
+
+ float d;
+};
+
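+// Widens 16 int16 scales into four int32x4 vectors.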
+inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
+ int32x4x4_t scales = {
+ vmovl_s16(vget_low_s16 (scales16.val[0])),
+ vmovl_s16(vget_high_s16(scales16.val[0])),
+ vmovl_s16(vget_low_s16 (scales16.val[1])),
+ vmovl_s16(vget_high_s16(scales16.val[1])),
+ };
+ return scales;
+}
+
+template <typename Q8>
+inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
+ accum_mins_16(scales16, q8, acc, i, c);
+ return make_wider(scales16);
+}
+
+struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
+ DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
+ }
+ inline void prepare(int i, int j) {
+
+ auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
+
+ bits.prepare64(x[i].ql+64*j);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
+
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
+
+ }
+
+ Q4bits bits;
+
+ const uint8x16_t mhb = vdupq_n_u8(0x30);
+
+ float d;
+};
+
+struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
+ DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ h.bits = vld1q_u8_x2(x[i].hmask);
+ mask = vdupq_n_u8(0x01);
+ const uint16_t * sc16 = (const uint16_t *)x[i].scales;
+ uint32_t aux0 = sc16[0] | (sc16[1] << 16);
+ uint32_t aux1 = sc16[2] | (sc16[3] << 16);
+ uint32_t aux2 = sc16[4] | (sc16[5] << 16);
+ aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
+ aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
+ aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
+ aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
+ auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
+ if (nrc > 1) {
+ return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
+ }
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
+ return make_wider(scales16);
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ if (nrc > 1) {
+ h.apply(bits.b1, bits.b2, j == 0);
+ } else {
+ auto minus4 = vdupq_n_u8(0xfc);
+ auto zero = vdupq_n_u8(0);
+ bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
+ bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
+ mask = vshlq_n_u8(mask, 1);
+ }
+ }
+
+ uint32_t aux32[4];
+
+ Q2bits bits;
+
+ uint8x16_t mask;
+ HighBit3 h;
+
+ float d;
+};
+
+struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
+ DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return true; }
+
+ template <typename Q8>
+ inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ auto scales_and_mins = vld1q_u8(x[i].scales);
+ auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
+ scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
+ accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
+
+ scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
+ }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ process_scales(i, q8, acc);
+ int16x8x2_t scales16;
+ scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
+ scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
+ return make_wider(scales16);
+ }
+
+ template <typename Q8>
+ inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
+ auto m1 = vdupq_n_u8(1);
+ auto shuffle = vdupq_n_u8(8*j);
+ bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
+
+ auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
+ vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
+
+ auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
+
+ auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
+ sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
+ vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
+ }
+ }
+
+ inline void prepare(int i, int j) {
+ bits.prepare(x[i].qs+32*j);
+ }
+
+ uint32_t aux32[4];
+
+ uint8x16_t scales8;
+
+ Q2bits bits;
+
+ float d;
+};
+
+// ============================= i-quants
+
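+// iq4_xs: 4-bit indices into the non-linear codebook below; the 6-bit block
+// scales are split between scales_l (4 low bits) and scales_h (2 high bits).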
+struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
+
+ static int8x16_t load_values() {
+ static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+ return vld1q_s8(iq4nl_values);
+ }
+
+ DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
+ (void)q8;
+ (void)acc;
+ d = GGML_FP16_TO_FP32(x[i].d);
+ const uint16_t scales_h = x[i].scales_h;
+ const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
+ aux32[0] = scales_l[0] | (scales_l[1] << 16);
+ aux32[1] = aux32[0] >> 4;
+ // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
+ uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
+ uint16_t * aux16 = (uint16_t *)aux32;
+ aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
+ // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
+ uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
+ int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
+ // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
+ scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
+ int16x8_t scales16 = vmovl_s8(scales8);
+ int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
+ return scales;
+ }
+ inline void prepare(int i, int j) {
+ bits.prepare16(x[i].qs+64*j);
+ //if (nrc == 1) {
+ // bits.prepare16_v2(x[i].qs+64*j);
+ //} else {
+ // bits.prepare16(x[i].qs+64*j);
+ //}
+ for (int k = 0; k < 4; ++k) {
+ bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
+ bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
+ }
+ }
+
+ Q4bits bits;
+ const int8x16_t values;
+ uint32_t aux32[2];
+
+ constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
+
+ float d;
+};
+
+struct SimpleBits {
+ uint8x16x4_t b1;
+ uint8x16x4_t b2;
+};
+
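+// The top nibble of each 32-bit sign/scale word holds the block scale s;
+// the effective scale used below is 2*s + 1.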
+inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
+ int32x4x2_t scales;
+ scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
+ scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
+ return scales;
+}
+
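+// Applies sign masks from the precomputed sign table (7-bit indices packed
+// into sidx) to two 16-byte vectors of grid values.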
+inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
+ auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
+ auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
+ b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+}
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+ DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+
+ auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
+ data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3
+ data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3
+ data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7
+ data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7
+
+ return prepare_scales_8(data.val[1], data.val[3]);
+ }
+
+ static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
+ b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
+ b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
+ apply_signs_2(b, signs, sidx);
+ }
+
+ inline void prepare(int /*i*/, int j) {
+ const uint8_t * idx = (const uint8_t *)(data.val + 2*j);
+ const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
+ prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4;
+ prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4;
+ prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4;
+ prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]);
+ }
+
+ uint32x4x4_t data;
+ SimpleBits bits;
+
+ float d;
+};
+
+inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
+ auto aux = vld1_u8(sc);
+ auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
+ auto scales_h = vshr_n_u8(aux, 4);
+ auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+
+ auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
+ int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
+ return make_wider(scales16);
+}
+
+struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
+ DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return prepare_4bit_scales16(x[i].scales);
+ }
+
+ inline static uint8x16_t make1(const uint16_t * qs) {
+ auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511))));
+ auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));
+ return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s));
+ }
+
+ inline static void make4(const uint16_t * qs, uint8x16_t * b) {
+ b[0] = make1(qs + 0);
+ b[1] = make1(qs + 2);
+ b[2] = make1(qs + 4);
+ b[3] = make1(qs + 6);
+ }
+
+ inline void prepare(int i, int j) {
+ make4(x[i].qs + 16*j + 0, bits.b1.val);
+ make4(x[i].qs + 16*j + 8, bits.b2.val);
+ }
+
+ SimpleBits bits;
+
+ float d;
+
+};
+
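+// Expands packed sign bits (one byte = 8 signs) into +/-1 bytes and multiplies
+// them into the grid values; the shuffle advances by two sign bytes per call.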
+struct SignHelper {
+
+ inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
+
+ inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
+ auto aux = vqtbl1q_u8(signs16, shuffle);
+ auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
+ b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
+ shuffle = vaddq_u8(shuffle, step);
+ }
+
+ const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+ const uint8x16_t m1 = vdupq_n_u8(1);
+ const uint8x16_t step = vdupq_n_u8(2);
+ uint8x16_t shuffle;
+};
+
+struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
+ DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 16; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+ return prepare_4bit_scales16(x[i].scales);
+ }
+
+ static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {
+ uint32_t aux32[2];
+ const uint16_t * aux16 = (const uint16_t *)aux32;
+ for (int k = 0; k < 2; ++k) {
+ aux32[1] = (qh[k] << 4) | (qh[k] << 18);
+ aux32[0] = (aux32[1] << 4) & 0x03000300;
+ aux32[1] &= 0x03000300;
+ b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),
+ vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));
+ sh.apply_signs_1(b+2*k+0, signs16);
+
+ b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),
+ vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));
+ sh.apply_signs_1(b+2*k+1, signs16);
+ }
+ }
+
+ inline void prepare(int i, int j) {
+
+ const auto * qs = x[i].qs + 16*j;
+ const auto * qh = x[i].qh + 4*j;
+ const auto signs16 = vld1q_u8(qs + QK_K/8);
+
+ sh.init();
+ make4(sh, signs16, qs+0, qh+0, bits.b1.val);
+ make4(sh, signs16, qs+8, qh+2, bits.b2.val);
+ }
+
+ SimpleBits bits;
+ SignHelper sh;
+
+ float d;
+
+};
+
+struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
+ DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = 0.25f * GGML_FP16_TO_FP32(x[i].d);
+ gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4));
+ return prepare_scales_8(gas.val[0], gas.val[1]);
+ }
+
+ inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) {
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});
+ apply_signs_2(b, keven_signs, sidx);
+ }
+ inline void prepare(int i, int j) {
+ const auto * q3 = x[i].qs + 32*j;
+ const auto * signs = (const uint32_t *)(gas.val + j);
+ make2(q3, signs[0], bits.b1.val + 0); q3 += 8;
+ make2(q3, signs[1], bits.b1.val + 2); q3 += 8;
+ make2(q3, signs[2], bits.b2.val + 0); q3 += 8;
+ make2(q3, signs[3], bits.b2.val + 2);
+ }
+
+ SimpleBits bits;
+ uint32x4x2_t gas;
+
+ float d;
+
+};
+
+struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
+ DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+ constexpr static int num_blocks() { return 8; }
+ constexpr static bool should_scale_quants() { return false; }
+
+ template <typename Q8>
+ inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+ d = GGML_FP16_TO_FP32(x[i].d);
+ uint32_t scales32[2];
+ std::memcpy(scales32, x[i].scales, 4);
+ scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
+ scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
+ auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
+ scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
+ auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
+ int32x4x2_t scales;
+ scales.val[0] = vmovl_s16(vget_low_s16(scales16));
+ scales.val[1] = vmovl_s16(vget_high_s16(scales16));
+ return scales;
+ }
+
+ static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
+ const int8x16_t& hshift, uint8x16_t * b) {
+ auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
+ const uint16_t * idx = (const uint16_t *)&vindex;
+ b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
+ b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
+ sh.apply_signs_1(b+0, signs16);
+ sh.apply_signs_1(b+1, signs16);
+ }
+ static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
+ const int8x16_t& hshift, uint8x16_t * b) {
+ auto idx_l = vld1q_u8(qs);
+ make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
+ make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
+ }
+
+ inline void prepare(int i, int j) {
+
+ static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+ const auto hshift = vld1q_s16(k_shift);
+
+ const auto * qs = x[i].qs + 32*j;
+ const auto * qh = x[i].qh + 4*j;
+ const auto signs16 = vld1q_u8(x[i].signs + 16*j);
+
+ sh.init();
+ make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
+ make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
+ }
+
+ SimpleBits bits;
+ SignHelper sh;
+ uint32x4x2_t gas;
+
+ float d;
+
+};
+
+
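+// Generic NEON kernel for k- and i-quants against q8_K activations: for each
+// super-block it loads the scales, unpacks the quants in two halves of 128
+// values, and accumulates int8 dot products for nrc_y activation rows.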
+template <int nrc_y, typename Dequantizer>
+void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ assert(n % QK_K == 0);
+ const int nb = n / QK_K;
+
+ Q8<nrc_y, block_q8_K> q8(info);
+
+ Dequantizer deq(vx, bx, nrc_y);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ float32x4_t acc[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb; ++i) {
+
+ int32x4_t sumi[nrc_y];
+ for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
+
+ if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
+ deq.process_scales(i, q8, acc);
+ deq.prepare(i, 0);
+ deq.compute(q8, i, 0, sumi);
+ deq.prepare(i, 1);
+ deq.compute(q8, i, 1, sumi);
+ } else {
+ if constexpr (Dequantizer::num_blocks() == 8) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else if constexpr (Dequantizer::num_blocks() == 16) {
+ auto scales = deq.new_block(i, q8, acc);
+ deq.prepare(i, 0);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
+ deq.prepare(i, 1);
+ for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
+ }
+ else {
+ GGML_ASSERT(false);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
+
+// =========================================== Legacy quants
+
+template <typename Block>
+inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {
+ for (int k = 0; k < 4; ++k) aux[k] = x[k].d;
+ return vld1_f16((const float16_t *)aux);
+}
+
+template <typename Block>
+inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {
+ if constexpr (std::is_same_v<Block, block_q8_1>) {
+ for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }
+ } else {
+ for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }
+ }
+ return vld1q_f16((const float16_t *)aux);
+}
+
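+// Unpacks four 4-bit legacy blocks (32 quants each) into eight int8x16 vectors.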
+struct Q4LegacyBits {
+ template <typename Block>
+ inline void prepare(const Block * x) {
+ for (int i = 0; i < 4; ++i) {
+ auto q4bits = vld1q_u8(x[i].qs);
+ b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+ b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+ }
+ }
+ inline void prepare1(const uint8_t * qs, int8x16_t * q) const {
+ auto q4bits = vld1q_u8(qs);
+ q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+ q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+ }
+ inline void prepare1(const uint8_t * qs) {
+ prepare1(qs, b);
+ }
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ int8x16_t b[8];
+};
+
+// One would think this commented-out version would do better than the one below
+// because it offers more opportunities to execute instructions in parallel.
+// Instead, it runs significantly slower. Why? If the compiler is running out of
+// vector registers, can it not simply fall back to the sequential version below on its own?
+//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
+// const auto q8b_1 = vld1q_s8_x2(qs + 0);
+// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);
+// const auto q8b_2 = vld1q_s8_x2(qs + 32);
+// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);
+// auto p1234 = vpaddq_s32(p12, p34);
+// const auto q8b_3 = vld1q_s8_x2(qs + 64);
+// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);
+// const auto q8b_4 = vld1q_s8_x2(qs + 96);
+// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);
+// return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
+//}
+
+inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {
+ auto q8b = vld1q_s8_x2(qs + 0);
+ auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 32);
+ auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);
+ auto p1234 = vpaddq_s32(p12, p34);
+ q8b = vld1q_s8_x2(qs + 64);
+ auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);
+ q8b = vld1q_s8_x2(qs + 96);
+ auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);
+ return vpaddq_s32(p1234, vpaddq_s32(p56, p78));
+}
+
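+// Q80/Q81 wrap the q8_0 / q8_1 activation rows. Scales and quants are read from
+// the interleaved block_q8_0_x4 / block_q8_1_x4 layout, so four blocks are
+// processed per iteration; process_1_block handles the leftover blocks.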
+template <int nrc> struct Q80 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q80(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
+ }
+
+ inline const int8_t * quant_data(int iy, int i) const {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
+ return y4->qs;
+ }
+
+ inline float16x4_t load_scales(int iy, int i) const {
+ const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;
+ return vld1_f16((const float16_t *)y4->d);
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {
+ auto qx_scales = deq.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ sc16[iy] = vmul_f16(qx_scales, q8_scales);
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
+ deq.prepare1(i);
+ float d = GGML_FP16_TO_FP32(deq.x[i].d);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8b = vld1q_s8_x2(y[iy][i].qs);
+ auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
+ acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
+ }
+ }
+
+ const block_q8_0 * y[nrc_y];
+};
+
+template <int nrc> struct Q81 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q81(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);
+ }
+
+ inline const int8_t * quant_data(int iy, int i) const {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
+ return y4->qs;
+ }
+
+ inline float16x8_t load_scales(int iy, int i) const {
+ const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;
+ return vld1q_f16((const float16_t *)y4->d);
+ }
+
+ template <typename Dequantizer>
+ inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {
+ auto qx_scales = deq.new_block(i);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8_scales = load_scales(iy, i);
+ auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));
+ acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));
+ sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));
+ }
+ }
+
+ template <typename Dequantizer>
+ inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {
+ deq.prepare1(i);
+ float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);
+ for (int iy = 0; iy < nrc; ++iy) {
+ auto q8b = vld1q_s8_x2(y[iy][i].qs);
+ auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);
+ acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));
+ acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));
+ }
+ }
+
+ const block_q8_1 * y[nrc_y];
+};
+
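+// Base class for the legacy (non-k) quant dequantizers: holds the 4-bit
+// unpacking helper plus the row base pointer and stride.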
+template <typename block_q>
+struct BaseLegacyDequantizer {
+
+ BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}
+
+ inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }
+
+ Q4LegacyBits bits;
+
+ const void * vx;
+ const block_q * x;
+ size_t bx;
+};
+
+struct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {
+
+ DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ q[0] = vaddq_s8(q[0], m8);
+ q[1] = vaddq_s8(q[1], m8);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ const int8x16_t m8 = vdupq_n_s8(-8);
+ //ggml_half aux[4];
+};
+
+struct DequantizerIQ4NL final : public BaseLegacyDequantizer<block_iq4_nl> {
+
+ DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ q[0] = vqtbl1q_s8(values, q[0]);
+ q[1] = vqtbl1q_s8(values, q[1]);
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+ static int8x16_t load_values() {
+ static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+ return vld1q_s8(iq4nl_values);
+ }
+
+ const int8x16_t values = load_values();
+};
+
+struct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {
+
+ DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.prepare1(x[i].qs);
+ }
+
+ inline float16x8_t new_block(int i) {
+ uint32_t aux32[4];
+ const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
+ for (int k = 0; k < 4; ++k) {
+ aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;
+ bits.prepare1(x[4*i+k].qs, bits.b + 2*k);
+ }
+ return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
+ }
+    // Leaving this attempt commented out as a reminder that I already tried it.
+ // It has basically the same performance as the version above.
+ //inline float16x8_t new_block(int i) {
+ // uint32x4_t scales = {};
+ // const block_q4_1 * xi = x + 4*i;
+ // const uint32_t * s32 = (const uint32_t *)&xi->d;
+ // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[0].qs, bits.b + 0);
+ // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[1].qs, bits.b + 2);
+ // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;
+ // bits.prepare1(xi[2].qs, bits.b + 4);
+ // scales = vsetq_lane_u32(*s32, scales, 3);
+ // bits.prepare1(xi[3].qs, bits.b + 6);
+ // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));
+ //}
+
+ const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
+};
+
+struct HighBit5Legacy {
+ inline uint8x16_t to_bytes(const uint8_t * qh) const {
+ uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
+ return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));
+ }
+ inline uint8x16_t to_negated_bytes(const uint8_t * qh) const {
+ uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);
+ return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));
+ }
+ const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);
+ const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
+};
+
+struct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {
+
+ DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh = x[i].qh;
+ q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+ q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+ }
+ inline void prepare1(int i) {
+ prepare1(i, bits.b);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+ HighBit5Legacy hbits;
+
+ const uint8x16_t mh = vdupq_n_u8(0xf0);
+
+};
+
+struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
+
+ DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i) {
+ bits.b[0] = vld1q_s8(x[i].qs);
+ bits.b[1] = vld1q_s8(x[i].qs+16);
+ }
+
+ inline float16x4_t new_block(int i) {
+ ggml_half aux[4];
+ for (int k = 0; k < 4; ++k) {
+ aux[k] = x[4*i+k].d;
+ bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);
+ bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);
+ }
+ return vld1_f16((const float16_t *)aux);
+ }
+
+};
+
+struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
+
+ DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
+
+ inline void prepare1(int i, int8x16_t * q) const {
+ bits.prepare1(x[i].qs, q);
+ auto qh = x[i].qh;
+ q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+ q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+ }
+ inline void prepare1(int i) {
+ bits.prepare1(x[i].qs, bits.b);
+ }
+
+ inline float16x8_t new_block(int i) {
+ uint32_t aux32[4];
+ const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
+ for (int k = 0; k < 4; ++k) {
+ aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;
+ prepare1(4*i+k, bits.b + 2*k);
+ }
+ return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
+ }
+
+ HighBit5Legacy hbits;
+
+ const uint8x16_t mh = vdupq_n_u8(0x10);
+ const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
+
+};
+
+template <typename Dequantizer, typename Q8>
+inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));
+ auto scale = vcvt_f32_f16(sc16[iy]);
+ acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));
+ }
+}
+
+template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[Q8::nrc_y];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq.new_row(ix);
+
+ float32x4_t acc[Q8::nrc_y];
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
+
+ for (int i = 0; i < nb/4; ++i) {
+ q8.process_scales(i, deq, sc16, acc);
+ sum_4(i, deq, q8, sc16, acc);
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ q8.process_1_block(i, deq, acc);
+ }
+
+ for (int iy = 0; iy < Q8::nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(acc[iy]));
+ }
+ }
+}
+
+template <typename Dequantizer, typename Q8>
+inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK4_1;
+
+ float16x4_t sc16[2];
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ deq1.new_row(ix);
+ deq2.new_row(ix);
+
+ float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };
+
+ for (int i = 0; i < nb/8; ++i) {
+ q8.process_scales(2*i+0, deq1, sc16+0, acc+0);
+ q8.process_scales(2*i+1, deq2, sc16+1, acc+1);
+ sum_4(2*i+0, deq1, q8, sc16+0, acc+0);
+ sum_4(2*i+1, deq2, q8, sc16+1, acc+1);
+ }
+ for (int i = 2*(nb/8); i < nb/4; ++i) {
+ q8.process_scales(i, deq1, sc16, acc);
+ sum_4(i, deq1, q8, sc16, acc);
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ q8.process_1_block(i, deq1, acc);
+ }
+
+ info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Q81<nrc_y> q8(info);
+ if constexpr (nrc_y == 1) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Q80<nrc_y> q8(info);
+ if constexpr (nrc_y == 1) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+ } else {
+ Dequantizer deq(vx, bx);
+ mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+ }
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ Q81<1> q8(info);
+ mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ Dequantizer deq1(vx, bx), deq2(vx, bx);
+ Q80<1> q8(info);
+    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
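+// fp16 x fp16 matrix multiplication helpers: products are accumulated in fp16
+// with vfmaq_f16 and only converted to fp32 for the final horizontal sum.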
+struct QF16Base {
+ constexpr static int k_step = 8;
+ using Data = float16x8_t;
+ using Acc = float16x8_t;
+ static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
+ static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); }
+ static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ return vfmaq_f16(prev, y, x);
+ }
+ static inline Acc acc_first(const Data& y, const Data& x) {
+ return vmulq_f16(y, x);
+ }
+ //constexpr static int k_step = 16;
+ //using Data = float16x8x2_t;
+ //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
+ //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
+ // return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
+ //}
+ //static inline Acc acc_first(const Data& y, const Data& x) {
+ // return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
+ //}
+ static inline float hsum(Acc acc) {
+ float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
+ return vaddvq_f32(sum);
+ }
+};
+template <int nrc> struct QF16 final : public QF16Base {
+ constexpr static int nrc_y = nrc;
+ QF16(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
+ }
+ QF16(const char * cx, size_t bx) {
+ for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
+ }
+ IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
+ IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
+ IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
+ const __fp16 * y[nrc_y];
+};
+
+template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    assert(!is_multiple_of_k_step || n%QF16Base::k_step == 0);
+ int nb = n/QF16Base::k_step;
+ QF16<nrc_y> y(info);
+ QF16<nrc_x> x(cx + ix0*bx, bx);
+ QF16Base::Data xv[nrc_x];
+ QF16Base::Acc acc[nrc_x*nrc_y];
+ auto yv = y.load1(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, 0);
+ acc[ix] = QF16Base::acc_first(yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
+ }
+ for (int i = 1; i < nb; ++i) {
+ yv = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load1(ix, i);
+ acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load1(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
+ }
+ }
+ if constexpr (!is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
+ yv = y.load_tail(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ xv[ix] = x.load_tail(ix, i);
+ acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
+ }
+ for (int iy = 1; iy < nrc_y; ++iy) {
+ yv = y.load_tail(iy, i);
+ for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
+ }
+ }
+ }
+ for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
+}
+
+template <int nrc_y>
+void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ constexpr int k_nx = 5;
+ const char * cx = (const char *)vx;
+ if (n%QF16Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+ mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
+ }
+ int last_x = k_nx*(nrc_x/k_nx);
+ if (last_x == nrc_x) return;
+ int nx = nrc_x - last_x;
+ switch (nx) {
+ case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
+ case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
+ case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
+ case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
+ }
+ }
+}
+
+template <int nrc_x, bool is_multiple_of_k_step>
+IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    assert(!is_multiple_of_k_step || n%QF16Base::k_step == 0);
+ int nb = n/QF16Base::k_step;
+ QF16<1> y(info);
+ QF16<nrc_x> x(cx + ix0*bx, bx);
+ QF16Base::Acc acc[4*nrc_x];
+ auto yv = y.loadx(0, 0);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, k);
+ acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
+ }
+ }
+ for (int i = 1; i < nb/4; ++i) {
+ yv = y.loadx(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ for (int k = 0; k < 4; ++k) {
+ auto xv = x.load1(ix, 4*i+k);
+ acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
+ }
+ }
+ }
+ for (int i = 4*(nb/4); i < nb; ++i) {
+ auto yv1 = y.load1(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load1(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ if constexpr (!is_multiple_of_k_step) {
+ int nb4 = n/4;
+ for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
+ auto yv1 = y.load_tail(0, i);
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto xv1 = x.load_tail(ix, i);
+ acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
+ }
+ }
+ }
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
+ auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
+ info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
+ }
+}
+
+// At least on my M2-Max the version below, which does the multiplication row-by-row, is faster.
+// But let's keep this version commented out for now.
+//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+// GGML_ASSERT(n%4 == 0);
+// constexpr int k_nx = 2;
+// const char * cx = (const char *)vx;
+// if (n%QF16Base::k_step == 0) {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
+// }
+// } else {
+// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
+// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
+// }
+// int last_x = k_nx*(nrc_x/k_nx);
+// if (last_x == nrc_x) return;
+// int nx = nrc_x - last_x;
+// switch (nx) {
+// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
+// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
+// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
+// }
+// }
+//}
+
+void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ GGML_ASSERT(n%4 == 0);
+ const char * cx = (const char *)vx;
+ if (n%QF16Base::k_step == 0) {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
+ }
+ } else {
+ for (int ix = 0; ix < nrc_x; ++ix) {
+ mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
+ }
+ }
+}
+
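+// Activation wrapper used by the iq1_bn / iq2_bn kernels: each q8_K64 row
+// starts with four float scales followed by the int8 quants.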
+template <int nrc> struct Q8_K64 {
+
+ constexpr static int nrc_y = nrc;
+
+ Q8_K64(const DataInfo& info) {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto dptr = (const float *)info.src1_row(iy);
+ std::memcpy(d + 4*iy, dptr, 4*sizeof(float));
+ y[iy] = (const int8_t *)(dptr + 4);
+ }
+ }
+
+ inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
+ inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
+ inline float32x4_t scale(int iy) const { return vld1q_f32(d + 4*iy); }
+
+ float d[4*nrc_y];
+ const int8_t * y[nrc_y];
+};
+
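+// Decodes iq1_bn blocks: the shuffle/multiply tables below extract one base-3
+// digit per lane from the packed bytes, and subtracting 1 maps the digits
+// {0,1,2} to the ternary values {-1,0,+1}.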
+struct DequantizerIQ1BN {
+ const uint8x16_t m1 = vdupq_n_u8(1);
+
+ static inline uint8x16x4_t load_shuffles() {
+ static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
+ 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
+ 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
+ 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
+ return vld1q_u8_x4(data);
+ }
+ static inline uint8x16x4_t load_mult() {
+ static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
+ 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
+ return vld1q_u8_x4(data);
+ }
+ const uint8x16x4_t shuff = load_shuffles();
+ const uint8x16x4_t mult = load_mult();
+
+ IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
+ auto data = vld1q_u8((const uint8_t *)x);
+ for (int k = 0; k < 4; ++k) {
+ auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
+ val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
+ v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
+ }
+ }
+};
+
+template <int nrc_y>
+static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+
+ Q8_K64<nrc_y> q8(info);
+ DequantizerIQ1BN deq;
+
+ int32x4_t accd[nrc_y];
+ int8x16x4_t v1, v2;
+
+ const block_iq1_bn * x = (const block_iq1_bn *)((const char *)vx);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ x = (const block_iq1_bn *)((const char *)vx + ix*bx);
+
+ if constexpr (nrc_y == 1) {
+ int32x4_t acc[4] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ deq.prepare_iq1bn_quants(x+2*i+0, v1);
+ auto q = q8.load_quants64(0, i, 0);
+ for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
+ deq.prepare_iq1bn_quants(x+2*i+1, v2);
+ q = q8.load_quants64(0, i, 1);
+ for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
+ }
+ accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+ }
+ else {
+
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+
+ for (int i = 0; i < nb/2; ++i) {
+
+ deq.prepare_iq1bn_quants(x+2*i+0, v1);
+ deq.prepare_iq1bn_quants(x+2*i+1, v2);
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ deq.prepare_iq1bn_quants(x+i, v1);
+ if constexpr (nrc_y == 1) {
+ auto q = q8.load_quants(0, i/2, 0);
+ for (int j = 0; j < 4; ++j) {
+ accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
+ }
+ } else {
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i/2, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ }
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ }
+
+ }
+}
+
+template <int nrc_y>
+static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+ const int nb = n / QK_IQ1BN;
+
+ Q8_K64<nrc_y> q8(info);
+
+ int32x4_t accd[nrc_y];
+
+ const auto m1 = vdupq_n_u8(1);
+ const auto mask2 = vdupq_n_s8(3);
+
+ for (int ix = 0; ix < nrc_x; ++ix) {
+
+ const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
+
+ if constexpr (nrc_y == 1) {
+ int8x16x4_t v1;
+ int32x4_t acc[4] = {};
+ for (int i = 0; i < nb/2; ++i) {
+ for (int j = 0; j < 2; ++j) {
+ auto q = q8.load_quants64(0, i, j);
+ auto q2bits = vld1q_u8(x[2*i+j].qs);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
+ acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
+ acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
+ acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
+ }
+ }
+ accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
+ } else {
+ int8x16x4_t v1, v2;
+ for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
+ for (int i = 0; i < nb/2; ++i) {
+ auto q2bits = vld1q_u8(x[2*i+0].qs);
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ q2bits = vld1q_u8(x[2*i+1].qs);
+ v2.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v2.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v2.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v2.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ q = q8.load_quants(iy, i, 2);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
+ q = q8.load_quants(iy, i, 3);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
+ }
+ }
+ }
+ int i = 2*(nb/2);
+ if (i < nb) {
+ auto q2bits = vld1q_u8(x[i].qs);
+ int8x16x4_t v1;
+ v1.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
+ v1.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
+ v1.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
+ v1.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ auto q = q8.load_quants(iy, i/2, 0);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
+ q = q8.load_quants(iy, i/2, 1);
+ accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
+ }
+ }
+
+ for (int iy = 0; iy < nrc_y; ++iy) {
+ info.store(ix, iy, vaddvq_f32(vmulq_f32(q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
+ }
+ }
+}
+
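+// Binds the kernels for 1..8 activation rows (funcs[0..7]) to the matching
+// dequantizer family: q8_0-based, q8_1-based, or q8_K-based.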
+template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
+ if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
+ std::is_same_v<Dequantizer, DequantizerQ80> || std::is_same_v<Dequantizer, DequantizerIQ4NL>) {
+ m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;
+ }
+ else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {
+ m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;
+ m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;
+ m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;
+ m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;
+ m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;
+ m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;
+ m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;
+ m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;
+ }
+ else {
+ m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;
+ m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;
+ m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;
+ m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;
+ m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;
+ m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;
+ m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;
+ m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;
+ }
+}
+
+bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
+
+ if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
+ if (ne00%4) return false;
+ for (auto& f : m.funcs) f = nullptr;
+ m.funcs[0] = mul_mat_f16_f16_1;
+ m.funcs[1] = mul_mat_f16_f16_T<2>;
+ m.funcs[2] = mul_mat_f16_f16_T<3>;
+ m.funcs[3] = mul_mat_f16_f16_T<4>;
+ m.funcs[4] = mul_mat_f16_f16_T<5>;
+ return true;
+ }
+
+ auto expected_Btype = GGML_TYPE_Q8_K;
+
+ switch (typeA) {
+ case GGML_TYPE_Q2_K:
+ MulMat::set_functions<DequantizerQ2K>(m);
+ break;
+ case GGML_TYPE_Q3_K:
+ MulMat::set_functions<DequantizerQ3K>(m);
+ break;
+ case GGML_TYPE_Q4_K:
+ MulMat::set_functions<DequantizerQ4K>(m);
+ break;
+ case GGML_TYPE_Q5_K:
+ MulMat::set_functions<DequantizerQ5K>(m);
+ break;
+ case GGML_TYPE_Q6_K:
+ MulMat::set_functions<DequantizerQ6K>(m);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ MulMat::set_functions<DequantizerIQ4XS>(m);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ MulMat::set_functions<DequantizerIQ2XXS>(m);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ MulMat::set_functions<DequantizerIQ2XS>(m);
+ break;
+ case GGML_TYPE_IQ2_S:
+ MulMat::set_functions<DequantizerIQ2S>(m);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ MulMat::set_functions<DequantizerIQ3XXS>(m);
+ break;
+ case GGML_TYPE_IQ3_S:
+ MulMat::set_functions<DequantizerIQ3S>(m);
+ break;
+ case GGML_TYPE_IQ1_BN:
+ m.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
+ m.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
+ m.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
+ m.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
+ m.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
+ m.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
+ m.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
+ m.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
+ expected_Btype = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_IQ2_BN:
+ m.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
+ m.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
+ m.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
+ m.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
+ m.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
+ m.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
+ m.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
+ m.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
+ expected_Btype = GGML_TYPE_Q8_K64;
+ break;
+ case GGML_TYPE_Q4_0:
+ MulMat::set_functions<DequantizerQ40>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_Q4_1:
+ MulMat::set_functions<DequantizerQ41>(m);
+ expected_Btype = GGML_TYPE_Q8_1;
+ break;
+ case GGML_TYPE_Q5_0:
+ MulMat::set_functions<DequantizerQ50>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_Q5_1:
+ MulMat::set_functions<DequantizerQ51>(m);
+ expected_Btype = GGML_TYPE_Q8_1;
+ break;
+ case GGML_TYPE_Q8_0:
+ MulMat::set_functions<DequantizerQ80>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
+ case GGML_TYPE_IQ4_NL:
+ MulMat::set_functions<DequantizerIQ4NL>(m);
+ expected_Btype = GGML_TYPE_Q8_0;
+ break;
+ default:
+ return false;
+ }
+
+ return typeB == expected_Btype;
+}
+
+}
+
+#endif // __aarch64__
+
+#else // IQK_IMPLEMENT
+
+bool iqk_mul_mat(long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
+    return false;
+}
+
+bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long,
+ const void *, int, int) {
+ return false;
+}
+
+#endif
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
new file mode 100644
index 00000000..6bed5f5a
--- /dev/null
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -0,0 +1,27 @@
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+#include <stdint.h>
+#include <stdbool.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+bool iqk_mul_mat(long Nx, long Ny, long ne00,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
+ float * C, long stride_C, int ith, int nth);
+
+bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
+ int typeA, const void * A, long strideA,
+ int typeB, const void * B, long strideB,
+ float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
+
+
+#ifdef __cplusplus
+}
+#endif
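
These two entry points are the entire public surface of the iqk kernels. The scalar fallbacks further down in iqk_quantize.cpp call iqk_mul_mat with Nx = Ny = 1 and zero strides to compute a single dot product; the sketch below only wraps that exact call shape. The helper itself is illustrative and not part of the patch: arow must hold one IQ2_BN-quantized row and brow one row produced by quantize_row_q8_K64 for the same n.

    // Minimal single-threaded sketch mirroring the fallback calls in iqk_quantize.cpp:
    // one quantized row of A against one quantized row of B, result written to *result.
    #include "iqk_mul_mat.h"
    #include "ggml.h"   // GGML_TYPE_IQ2_BN, GGML_TYPE_Q8_K64

    static bool dot_iq2_bn_row(const void * arow, const void * brow, int n, float * result) {
        // Strides are passed as 0 because only one row/column is involved;
        // ith = 0, nth = 1 runs the whole product on the calling thread.
        return iqk_mul_mat(/*Nx=*/1, /*Ny=*/1, /*ne00=*/n,
                           GGML_TYPE_IQ2_BN, arow, 0,
                           GGML_TYPE_Q8_K64, brow, 0,
                           result, /*stride_C=*/0, /*ith=*/0, /*nth=*/1);
    }
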
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
new file mode 100644
index 00000000..8f541565
--- /dev/null
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -0,0 +1,414 @@
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#if GGML_USE_IQK_MULMAT
+#include "iqk_mul_mat.h"
+#endif
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include <vector>
+#include <utility>
+#include <cstdint>
+#include <cmath>
+#include <array>
+#include <algorithm>
+#include <cstring>
+
+namespace {
+
+inline int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+}
+
+struct IQ1BNQuantizer {
+ int8_t L[QK_IQ1BN];
+ void quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix);
+ void quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix);
+ static inline float row_max(int n_per_row, const float * src) {
+ float max_in_row = 0;
+ for (int j = 0; j < n_per_row; ++j) {
+ float ax = fabsf(src[j]);
+ max_in_row = std::max(max_in_row, ax);
+ }
+ return max_in_row;
+ }
+ static constexpr uint8_t k_mult[5] = {81, 27, 9, 3, 1};
+};
+
+void IQ1BNQuantizer::quantize_one_row_1bn(const float * src, block_iq1_bn * y, int n_per_row, const float * imatrix) {
+
+ static const int k_nb[6] = {1, 3, 9, 27, 81, 243};
+ (void)imatrix;
+
+ const int nblock = n_per_row/QK_IQ1BN;
+
+ for (int ib = 0; ib < nblock; ++ib) {
+ std::memset(&y[ib], 0, sizeof(block_iq1_bn));
+ auto xb = src + ib*QK_IQ1BN;
+ int v13 = 0;
+ for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+ for (int k = 0; k < 3; ++k) {
+ int idx = 0;
+ for (int j = 0; j < 5; ++j) {
+ float v = xb[16*i16 + 5*k + j];
+ int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+ idx += k_nb[j]*q;
+ }
+ idx = (256*idx + k_nb[5] - 1)/k_nb[5];
+ y[ib].ql[3*i16 + k] = idx;
+ }
+ float v = xb[16*i16 + 15];
+ int q = fabsf(v) < 1e-6f ? 1 : v < 0 ? 0 : 2;
+ v13 += k_nb[i16]*q;
+ }
+ y[ib].extra = (256*v13 + k_nb[5] - 1)/k_nb[5];
+ }
+}
+
+void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, int n_per_row, const float * imatrix) {
+
+ (void)imatrix;
+
+ const int nblock = n_per_row/QK_IQ1BN;
+
+ constexpr int Nj = QK_IQ1BN/4;
+
+ for (int ib = 0; ib < nblock; ++ib) {
+ auto xb = src + QK_IQ1BN*ib;
+ for (int j = 0; j < QK_IQ1BN; ++j) {
+ L[j] = fabsf(xb[j]) < 1e-6f ? 1 : xb[j] < 0 ? 0 : 2;
+ }
+ for (int j = 0; j < Nj; ++j) {
+ y[ib].qs[j] = L[j] | (L[j + Nj] << 2) | (L[j + 2*Nj] << 4) | (L[j + 3*Nj] << 6);
+ }
+ }
+}
+
+}
+
+size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ IQ1BNQuantizer iq1bn;
+ int nblock = n_per_row/QK_IQ1BN;
+ block_iq1_bn * y = (block_iq1_bn *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ iq1bn.quantize_one_row_1bn(src + row*n_per_row, y, n_per_row, imatrix);
+ y += nblock;
+ }
+ return sizeof(block_iq1_bn)*nblock*nrows;
+}
+
+void quantize_row_iq1_bn_ref(const float * x, block_iq1_bn * y, int64_t k) {
+ quantize_iq1_bn(x, y, 1, k, nullptr);
+}
+
+void quantize_row_iq1_bn(const float * x, void * y, int64_t k) {
+ quantize_iq1_bn(x, y, 1, k, nullptr);
+}
+
+void dequantize_row_iq1_bn(const block_iq1_bn * x, float * y, int64_t k) {
+ assert(k%QK_IQ1BN == 0);
+ int nblock = k / QK_IQ1BN;
+
+ for (int i = 0; i < nblock; ++i) {
+ uint8_t extra = x[i].extra;
+ auto ql = x[i].ql;
+ for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+ for (int k = 0; k < 3; ++k) {
+ for (int j = 0; j < 5; ++j) {
+ uint8_t v = ql[k]*IQ1BNQuantizer::k_mult[j];
+ int8_t vs = ((v + (v >> 1)) >> 7);
+ *y++ = vs - 1;
+ }
+ }
+ ql += 3;
+ uint8_t v = extra*IQ1BNQuantizer::k_mult[i16];
+ int8_t vs = ((v + (v >> 1)) >> 7);
+ *y++ = vs - 1;
+ }
+ }
+}
+
+size_t quantize_iq2_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) {
+ IQ1BNQuantizer iq1bn;
+ int nblock = n_per_row/QK_IQ1BN;
+ block_iq2_bn * y = (block_iq2_bn *)dst;
+ for (int row = 0; row < nrows; ++row) {
+ iq1bn.quantize_one_row_2bn(src + row*n_per_row, y, n_per_row, imatrix);
+ y += nblock;
+ }
+ return sizeof(block_iq2_bn)*nblock*nrows;
+}
+
+void quantize_row_iq2_bn_ref(const float * x, block_iq2_bn * y, int64_t k) {
+ quantize_iq2_bn(x, y, 1, k, nullptr);
+}
+
+void quantize_row_iq2_bn(const float * x, void * y, int64_t k) {
+ quantize_iq2_bn(x, y, 1, k, nullptr);
+}
+
+void dequantize_row_iq2_bn(const block_iq2_bn * x, float * y, int64_t k) {
+ assert(k%QK_IQ1BN == 0);
+ int nblock = k / QK_IQ1BN;
+
+ auto d1 = 1.f, d2 = 0.25f, d3 = d2*0.25f, d4 = d3*0.25f;
+ auto m = -1.f;
+ constexpr int Nj = QK_IQ1BN/4;
+ for (int i = 0; i < nblock; ++i) {
+ for (int j = 0; j < Nj; ++j) {
+ y[j+ 0] = d1*(x[i].qs[j] & 0x03) + m;
+ y[j+1*Nj] = d2*(x[i].qs[j] & 0x0c) + m;
+ y[j+2*Nj] = d3*(x[i].qs[j] & 0x30) + m;
+ y[j+3*Nj] = d4*(x[i].qs[j] & 0xc0) + m;
+ }
+ y += QK_IQ1BN;
+ }
+}
+
+namespace {
+inline int8_t iq1bn_dequant(uint8_t q, int i) {
+ uint8_t v = IQ1BNQuantizer::k_mult[i]*q;
+ //int8_t vs = (v + (v << 1)) >> 8;
+ int8_t vs = 3*v >> 8;
+ return vs - 1;
+}
+}
+
+static const int8_t iq1bn_values[1280] = {
+ -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 0, -1, -1, -1, 0, 0, -1, -1, -1, 1, 0,
+ -1, -1, -1, -1, 1, -1, -1, -1, 0, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, 0, -1, -1, 0, -1, 0, -1, -1, 1, -1, 0, -1,
+ -1, -1, 0, 0, -1, -1, 0, 0, 0, -1, -1, 1, 0, 0, -1, -1, -1, 1, 0, -1, -1, 0, 1, 0, -1, -1, 1, 1, 0, -1, -1, -1,
+ -1, 1, -1, -1, 0, 0, 0, 0, 0, 0, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, 0, 1, -1, -1, 0, 0, 1, -1, -1, 1, 0, 1,
+ -1, -1, -1, 1, 1, -1, -1, 0, 1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 0, -1, 0, -1, -1, 0, -1, 1, -1, -1, 0, -1,
+ -1, 0, -1, 0, -1, 0, 0, -1, 0, -1, 1, 0, -1, 0, -1, -1, 1, -1, 0, -1, 0, 1, -1, 0, -1, 1, 1, -1, 0, -1, -1, -1,
+ 0, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, -1, 0, 0, 0, -1, 0, 0, 0, 0, -1, 1, 0, 0, 0,
+ -1, -1, 1, 0, 0, -1, 0, 1, 0, 0, -1, 1, 1, 0, 0, -1, -1, -1, 1, 0, -1, 0, -1, 1, 0, -1, 1, -1, 1, 0, -1, -1,
+ 0, 1, 0, -1, 0, 0, 1, 0, -1, 1, 0, 1, 0, -1, -1, 1, 1, 0, -1, 0, 1, 1, 0, -1, 1, 1, 1, 0, -1, -1, -1, -1,
+ 1, -1, 0, -1, -1, 1, -1, 1, -1, -1, 1, -1, 0, 0, 0, 0, 0, -1, 0, -1, 1, -1, 0, 0, -1, 1, -1, 1, 0, -1, 1, -1,
+ -1, 1, -1, 1, -1, 0, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 0, 1, -1, 0, -1, 0, 1, -1, 1, -1, 0, 1, -1, -1, 0,
+ 0, 1, -1, 0, 0, 0, 1, -1, 1, 0, 0, 1, -1, -1, 1, 0, 1, -1, 0, 1, 0, 1, -1, 1, 1, 0, 1, -1, -1, -1, 1, 1,
+ -1, 0, -1, 1, 1, -1, 1, -1, 1, 1, -1, 0, 0, 0, 0, 0, -1, 0, 1, 1, -1, 0, 0, 1, 1, -1, 1, 0, 1, 1, -1, -1,
+ 1, 1, 1, -1, 0, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, 0, 1, -1, -1, -1, 0, -1, 0, -1,
+ -1, 0, 0, 0, -1, -1, 0, 1, 0, -1, -1, 0, -1, 1, -1, -1, 0, 0, 1, -1, -1, 0, 1, 1, -1, -1, 0, -1, -1, 0, -1, 0,
+ 0, -1, 0, -1, 0, 1, -1, 0, -1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 1, 0, 0, -1, 0, -1, 1,
+ 0, -1, 0, 0, 1, 0, -1, 0, 1, 1, 0, -1, 0, -1, -1, 1, -1, 0, 0, -1, 1, -1, 0, 1, -1, 1, -1, 0, -1, 0, 1, -1,
+ 0, 0, 0, 1, -1, 0, 1, 0, 1, -1, 0, -1, 1, 1, -1, 0, 0, 1, 1, -1, 0, 1, 1, 1, -1, 0, -1, -1, -1, 0, 0, 0,
+ -1, -1, 0, 0, 1, -1, -1, 0, 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 1, -1,
+ 0, 0, 0, 1, -1, 0, 0, 1, 1, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, -1, 0, 0, 0, -1, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, -1, -1, 1, 0, 0, 0, -1,
+ 1, 0, 0, 1, -1, 1, 0, 0, -1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, -1, 1, 1, 0,
+ 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, -1, -1, -1, 1, 0, 0, -1, -1, 1, 0, 1, -1, -1, 1, 0, -1, 0, -1, 1, 0, 0,
+ 0, -1, 1, 0, 1, 0, -1, 1, 0, -1, 1, -1, 1, 0, 0, 1, -1, 1, 0, 1, 1, -1, 1, 0, -1, -1, 0, 1, 0, 0, -1, 0,
+ 1, 0, 1, -1, 0, 1, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 1, 0, 1, 0,
+ 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, -1, -1, 1, 1, 0, 0, -1, 1, 1, 0, 1, -1, 1, 1, 0, -1, 0, 1, 1, 0, 0, 0,
+ 1, 1, 0, 1, 0, 1, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, -1, -1, -1, -1, 1, 0, -1, -1, -1,
+ 1, 1, -1, -1, -1, 1, -1, 0, -1, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, -1, 1, -1, 1, -1, -1, 1, 0, 0, 0, 0, 0, 0,
+ 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 0, -1, 1, 0, -1, 0, -1, 1, 1, -1, 0, -1, 1, -1, 0, 0, -1, 1, 0, 0, 0,
+ -1, 1, 1, 0, 0, -1, 1, -1, 1, 0, -1, 1, 0, 1, 0, -1, 1, 1, 1, 0, -1, 1, -1, -1, 1, -1, 1, 0, -1, 1, -1, 1,
+ 1, -1, 1, -1, 1, -1, 0, 1, -1, 1, 0, 0, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1, 1, -1, 1, 0, 0, 0, 0, 0, 0, 1,
+ 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, -1, 0, 1, 0, -1, -1, 0, 1, 1, -1, -1, 0, 1, -1, 0, -1, 0, 1, 0, 0, -1, 0,
+ 1, 1, 0, -1, 0, 1, -1, 1, -1, 0, 1, 0, 1, -1, 0, 1, 1, 1, -1, 0, 1, -1, -1, 0, 0, 1, 0, -1, 0, 0, 1, 1,
+ -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, -1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0,
+ 0, 0, 1, 1, 0, 0, 1, -1, -1, 1, 0, 1, 0, -1, 1, 0, 1, 1, -1, 1, 0, 1, -1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
+ 1, 0, 1, 0, 1, -1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, -1, -1, -1, 1, 1, 0, -1, -1, 1, 1, 1, -1,
+ -1, 1, 1, -1, 0, -1, 1, 1, 0, 0, -1, 1, 1, 1, 0, -1, 1, 1, -1, 1, -1, 1, 1, 0, 1, -1, 1, 1, 1, 1, -1, 1,
+ 1, 0, 0, 0, 0, 0, -1, -1, 0, 1, 1, 0, -1, 0, 1, 1, 1, -1, 0, 1, 1, -1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
+ 0, 0, 1, 1, -1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, -1, -1, 1, 1, 1, 0, -1, 1, 1, 1, 1, -1, 1,
+ 1, 1, -1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, -1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+};
+
+void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(nrc);
+
+ static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");
+
+#if GGML_USE_IQK_MULMAT
+ if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+ return;
+ }
+#endif
+
+ const block_iq1_bn * x = (const block_iq1_bn *)vx;
+
+ const float * d8 = (const float *)vy;
+ const int8_t * q8 = (const int8_t *)(d8 + 4);
+ int nblock = n / QK_IQ1BN;
+
+ int sumi[8] = {};
+ int8_t q1[16];
+
+ for (int ii = 0; ii < nblock; ii += 32) {
+ int16_t sum16[8] = {};
+ int nb = std::min(ii + 32, nblock);
+ for (int i = ii; i < nb; ++i) {
+ auto ql = x[i].ql;
+ const int8_t * extra = iq1bn_values + 5*x[i].extra;
+ for (int i16 = 0; i16 < QK_IQ1BN/16; ++i16) {
+ for (int k = 0; k < 3; ++k) {
+ uint8_t q = *ql++;
+ const int8_t * vs = iq1bn_values + 5*q;
+ for (int j = 0; j < 5; ++j) q1[5*k+j] = vs[j];
+ }
+ q1[15] = extra[i16];
+            // We accumulate 8 q8*q1 products per block into each element of sum16.
+            // With at most 32 blocks per loop over i that is 32 x 8 = 256 products, each
+            // bounded by 127 (q8 is in -127...127, q1 in -1...1), so the partial sums stay
+            // within -32512...32512 and cannot overflow the int16_t range.
+ for (int j = 0; j < 8; ++j) sum16[j] += q8[2*j+0]*q1[2*j+0] + q8[2*j+1]*q1[2*j+1];
+ q8 += 16;
+ }
+ }
+ for (int j = 0; j < 8; ++j) sumi[j] += sum16[j];
+ }
+
+ *s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]);
+}
+
+void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+
+ GGML_ASSERT(nrc == 1);
+ GGML_UNUSED(bs);
+ GGML_UNUSED(bx);
+ GGML_UNUSED(by);
+ GGML_UNUSED(nrc);
+
+ static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64");
+
+#if GGML_USE_IQK_MULMAT
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+        return;
+    }
+#endif
+
+ constexpr int Nj = QK_IQ1BN/4;
+
+ const block_iq2_bn * x = (const block_iq2_bn *)vx;
+ int nblock = n / QK_IQ1BN;
+
+ const float * d = (const float *)vy;
+ const int8_t * q8 = (const int8_t *)(d + 4);
+
+ int sum[16] = { };
+ int sum0[4] = { };
+
+ for (int i = 0; i < nblock; ++i) {
+ for (int j = 0; j < Nj/4; ++j) {
+ for (int l = 0; l < 4; ++l) {
+ sum[4*j + 0] += q8[4*j + l + 0] * (x[i].qs[4*j+l] & 0x03);
+ sum[4*j + 1] += q8[4*j + l + 1*Nj] * (x[i].qs[4*j+l] & 0x0c);
+ sum[4*j + 2] += q8[4*j + l + 2*Nj] * (x[i].qs[4*j+l] & 0x30);
+ sum[4*j + 3] += q8[4*j + l + 3*Nj] * (x[i].qs[4*j+l] & 0xc0);
+ sum0[j] += q8[4*j + l] + q8[4*j + l + 1*Nj] + q8[4*j + l + 2*Nj] + q8[4*j + l + 3*Nj];
+ }
+ }
+ q8 += QK_IQ1BN;
+ }
+
+ float sumf = 0;
+ for (int j = 0; j < 4; ++j) {
+ sumf += d[j] * (sum[4*j + 0] + 0.25f*sum[4*j + 1] + 0.0625*sum[4*j + 2] + 0.015625*sum[4*j + 3] - sum0[j]);
+ }
+ *s = sumf;
+
+}
+
+void quantize_row_q8_K64_ref(const float * x, block_q8_K64 * y, int64_t k) {
+
+ float * dptr = (float *)y;
+ auto qs = (int8_t *)(dptr + 4);
+#ifdef __ARM_NEON
+ static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
+ auto shuffle = vld1q_u8(k_shuffle);
+ float32x4_t max[4] = { };
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ auto val = vld1q_f32(x + j + 4*i);
+ val = vabsq_f32(val);
+ max[i] = vmaxq_f32(max[i], val);
+ }
+ }
+ float32x4_t vid[4];
+ for (int i = 0; i < 4; ++i) {
+ dptr[i] = vmaxvq_f32(max[i])/127;
+ float id = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+ vid[i] = vdupq_n_f32(id);
+ }
+ int8x16x4_t q;
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ auto val = vld1q_f32(x + j + 4*i);
+ val = vmulq_f32(vid[i], val);
+ q.val[i] = vreinterpretq_s8_s32(vcvtnq_s32_f32(val));
+ }
+ auto qi = vqtbl4q_s8(q, shuffle);
+ vst1q_s8(qs, qi);
+ qs += 16;
+ }
+#elif defined __AVX__
+ __m128 max[4] = {};
+ __m128 sign_bit = _mm_set1_ps(-0.f);
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ auto val = _mm_loadu_ps(x + j + 4*i);
+ val = _mm_andnot_ps(sign_bit, val);
+ max[i] = _mm_max_ps(max[i], val);
+ }
+ }
+ __m128 vid[4];
+ for (int i = 0; i < 4; ++i) {
+ max[i] = _mm_max_ps(max[i], _mm_movehl_ps(max[i], max[i]));
+ max[i] = _mm_max_ss(max[i], _mm_movehdup_ps(max[i]));
+ float maxi = _mm_cvtss_f32(max[i]);
+ dptr[i] = maxi/127;
+ float id = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+ vid[i] = _mm_set1_ps(id);
+ }
+ __m128i q[4];
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ auto val = _mm_loadu_ps(x + j + 4*i);
+ val = _mm_round_ps(_mm_mul_ps(vid[i], val), _MM_ROUND_NEAREST);
+ q[i] = _mm_cvtps_epi32(val);
+ }
+ auto q1 = _mm_packs_epi32(q[0], q[1]);
+ auto q2 = _mm_packs_epi32(q[2], q[3]);
+ auto qi = _mm_packs_epi16(q1, q2);
+ _mm_storeu_si128((__m128i *)qs, qi);
+ qs += 16;
+ }
+#else
+ float aux[4] = {0.f, 0.f, 0.f, 0.f};
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ for (int l = 0; l < 4; ++l) {
+ float ax = fabsf(x[j+4*i+l]);
+ aux[i] = std::max(aux[i], ax);
+ }
+ }
+ }
+ for (int i = 0; i < 4; ++i) {
+ dptr[i] = aux[i]/127;
+ aux[i] = dptr[i] > 0 ? 1/dptr[i] : 0.f;
+ }
+ for (int j = 0; j < k; j += 16) {
+ for (int i = 0; i < 4; ++i) {
+ for (int l = 0; l < 4; ++l) qs[j+4*i+l] = nearest_int(aux[i]*x[j+4*i+l]);
+ }
+ }
+#endif
+}
+
+void quantize_row_q8_K64(const float * x, void * y, int64_t k) {
+ quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k);
+}
+
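
The ternary packing implemented above is compact enough to check in isolation: quantize_one_row_1bn folds five base-3 digits into one byte by scaling the group value 0..242 up to 0..255 with rounding up, and dequantize_row_iq1_bn recovers digit j by multiplying with k_mult[j] = 3^(4-j), letting the product wrap modulo 256, and taking (v + (v >> 1)) >> 7. A standalone round-trip check of just that packing, with the tables copied from the code above and the test harness itself being purely illustrative:

    // Exhaustive round-trip check of the per-byte base-3 packing used by iq1_bn.
    #include <cassert>
    #include <cstdint>

    int main() {
        static const int     k_nb[6]   = {1, 3, 9, 27, 81, 243};  // 3^j weights, as in the quantizer
        static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};       // 3^(4-j), as in the dequantizer
        for (int idx = 0; idx < 243; ++idx) {                     // every possible group of 5 ternary digits
            uint8_t ql = (256*idx + k_nb[5] - 1)/k_nb[5];         // encode: scale 0..242 -> 0..255, rounding up
            for (int j = 0; j < 5; ++j) {
                int     expected = (idx / k_nb[j]) % 3;           // the j-th base-3 digit
                uint8_t v  = ql*k_mult[j];                        // decode: wraps mod 256 on purpose
                int     vs = (v + (v >> 1)) >> 7;                 // == floor(3*v/256): the top ternary digit
                assert(vs == expected);                           // every digit survives the byte packing
            }
        }
        return 0;
    }
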
diff --git a/ggml/src/kompute b/ggml/src/kompute
new file mode 160000
+Subproject 4565194ed7c32d1d2efa32ceab4d3c6cae00630
diff --git a/ggml/src/kompute-shaders/common.comp b/ggml/src/kompute-shaders/common.comp
new file mode 100644
index 00000000..62d62b02
--- /dev/null
+++ b/ggml/src/kompute-shaders/common.comp
@@ -0,0 +1,102 @@
+#extension GL_EXT_shader_16bit_storage: require
+#extension GL_EXT_shader_8bit_storage: require
+#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
+#extension GL_EXT_control_flow_attributes: enable
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#extension GL_EXT_debug_printf : enable
+
+#define QK4_0 32
+#define QK4_1 32
+
+#define GELU_COEF_A 0.044715
+#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
+#define TWOPI_F 6.283185307179586f
+
+#define QK_K 256
+
+#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
+#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
+#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
+#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
+
+#define sizeof_block_q4_0 0x12
+struct block_q4_0 {
+ float16_t d;
+ uint8_t qs[QK4_0 / 2];
+};
+mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
+ const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.f * xb.d;
+ const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
+ const uint16_t mask1 = mask0 << 8;
+
+ mat4 reg;
+ for (int i=0;i<8;i++) {
+ uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
+ reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
+ }
+ return reg;
+}
+
+#define sizeof_block_q4_1 0x14
+struct block_q4_1 {
+ float16_t d;
+ float16_t m;
+ uint8_t qs[QK4_1 / 2];
+};
+mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
+ const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
+ const float d2 = d1 / 256.f;
+ const float m = xb.m;
+ const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
+ const uint16_t mask1 = mask0 << 8;
+
+ mat4 reg;
+ for (int i=0;i<8;i++) {
+ uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
+ reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
+ }
+ return reg;
+}
+
+#define sizeof_block_q6_k 210
+struct block_q6_k {
+ uint8_t ql[QK_K/2]; // quants, lower 4 bits
+ uint8_t qh[QK_K/4]; // quants, upper 2 bits
+ int8_t scales[QK_K/16]; // scales, quantized with 8 bits
+ float16_t d; // super-block scale
+};
+mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
+ const float16_t d_all = xb.d;
+
+ const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ const uint qhIndex = 32*(il/8) + 16*(il&1);
+ float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
+
+ const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
+ const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
+ const float16_t ml = float16_t(d_all * sc * 32.f);
+ const float16_t dl = float16_t(d_all * sc * coef);
+ mat4 reg;
+ for (int i = 0; i < 16; ++i) {
+ const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
+ : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
+ }
+ return reg;
+}
+
+
+#define QK8_0 32
+// struct block_q8_0 {
+// float16_t d; // delta
+// int8_t qs[QK8_0]; // quants
+// };
+#define sizeof_block_q8_0 34
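
For reference against the mat4 shuffle above: the q4_0 block consumed by dequantize_q4_0 has the standard ggml layout, an fp16 scale d followed by 16 bytes of packed nibbles, where the low nibbles map to weights 0..15, the high nibbles to weights 16..31, and both are recentered by 8 before scaling (that is what the md = -8*d term folds in). A plain scalar reference of that layout, with a helper name of our own:

    // Scalar reference for one q4_0 block: fp16 scale d plus 16 bytes, two 4-bit quants each.
    #include <cstdint>

    static void dequantize_q4_0_block(float d, const uint8_t qs[16], float out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = d * (float)((qs[i] & 0x0F) - 8);  // low nibbles -> weights 0..15
            out[i + 16] = d * (float)((qs[i] >> 4)   - 8);  // high nibbles -> weights 16..31
        }
    }
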
diff --git a/ggml/src/kompute-shaders/op_add.comp b/ggml/src/kompute-shaders/op_add.comp
new file mode 100644
index 00000000..b7b76a79
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_add.comp
@@ -0,0 +1,58 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1024) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb00;
+ int nb01;
+ int nb02;
+ int nb03;
+ int ne10;
+ int ne11;
+ int ne12;
+ int ne13;
+ int nb10;
+ int nb11;
+ int nb12;
+ int nb13;
+ int ne0;
+ int nb0;
+ int nb1;
+ int nb2;
+ int nb3;
+ //int offs; // TODO: needed for GGML_OP_ACC, see metal code
+} pcs;
+
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const uint i13 = i03 % pcs.ne13;
+ const uint i12 = i02 % pcs.ne12;
+ const uint i11 = i01 % pcs.ne11;
+
+ int offs = 0; // TMP (see above)
+
+ uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
+ uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
+ uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
+
+ for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
+ const uint i10 = i0 % pcs.ne10;
+ out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
+ }
+}
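
The broadcast mentioned in the comment is implemented purely with modulo indexing: the indices into B wrap as i13 = i03 % ne13, i12 = i02 % ne12, i11 = i01 % ne11 and i10 = i0 % ne10, so a B tensor that is smaller along those dims is simply repeated. A scalar restatement of the same rule, assuming contiguous float storage instead of the byte strides (nb*) the shader receives via push constants:

    // Element-wise add with broadcasting of B, matching the modulo indexing in op_add.comp.
    static void add_broadcast(const float * A, const float * B, float * out,
                              int ne0,  int ne1,  int ne2,  int ne3,    // shape of A and of the output
                              int ne10, int ne11, int ne12, int ne13) { // shape of B
        for (int i3 = 0; i3 < ne3; ++i3)
        for (int i2 = 0; i2 < ne2; ++i2)
        for (int i1 = 0; i1 < ne1; ++i1)
        for (int i0 = 0; i0 < ne0; ++i0) {
            const int ia = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
            const int ib = (((i3 % ne13)*ne12 + (i2 % ne12))*ne11 + (i1 % ne11))*ne10 + (i0 % ne10);
            out[ia] = A[ia] + B[ib];
        }
    }
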
diff --git a/ggml/src/kompute-shaders/op_addrow.comp b/ggml/src/kompute-shaders/op_addrow.comp
new file mode 100644
index 00000000..2376a6b8
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_addrow.comp
@@ -0,0 +1,25 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ uint row;
+} pcs;
+
+void main() {
+ const uint baseIndex = gl_WorkGroupID.x * 4;
+
+ for (uint x = 0; x < 4; x++) {
+ const uint i = baseIndex + x;
+ out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f16.comp b/ggml/src/kompute-shaders/op_cpy_f16_f16.comp
new file mode 100644
index 00000000..d57247d2
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_cpy_f16_f16.comp
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float16_t
+#define IN_TYPE_SIZE 2
+#define OUT_TYPE float16_t
+#define OUT_TYPE_SIZE 2
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ uint nb03;
+ int ne0;
+ int ne1;
+ int ne2;
+ uint nb0;
+ uint nb1;
+ uint nb2;
+ uint nb3;
+} pcs;
+
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+ const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+ const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+ const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+ const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+ const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+ out_[dst_data+i00] = OUT_TYPE(in_[src]);
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f32.comp b/ggml/src/kompute-shaders/op_cpy_f16_f32.comp
new file mode 100644
index 00000000..b568bcd7
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_cpy_f16_f32.comp
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float16_t
+#define IN_TYPE_SIZE 2
+#define OUT_TYPE float
+#define OUT_TYPE_SIZE 4
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ uint nb03;
+ int ne0;
+ int ne1;
+ int ne2;
+ uint nb0;
+ uint nb1;
+ uint nb2;
+ uint nb3;
+} pcs;
+
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+ const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+ const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+ const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+ const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+ const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+ out_[dst_data+i00] = OUT_TYPE(in_[src]);
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f16.comp b/ggml/src/kompute-shaders/op_cpy_f32_f16.comp
new file mode 100644
index 00000000..99b22834
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_cpy_f32_f16.comp
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float
+#define IN_TYPE_SIZE 4
+#define OUT_TYPE float16_t
+#define OUT_TYPE_SIZE 2
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ uint nb03;
+ int ne0;
+ int ne1;
+ int ne2;
+ uint nb0;
+ uint nb1;
+ uint nb2;
+ uint nb3;
+} pcs;
+
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+ const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+ const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+ const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+ const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+ const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+ out_[dst_data+i00] = OUT_TYPE(in_[src]);
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f32.comp b/ggml/src/kompute-shaders/op_cpy_f32_f32.comp
new file mode 100644
index 00000000..2fc99849
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_cpy_f32_f32.comp
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+#define IN_TYPE float
+#define IN_TYPE_SIZE 4
+#define OUT_TYPE float
+#define OUT_TYPE_SIZE 4
+
+layout(local_size_x = 1024) in;
+
+layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
+layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ uint nb03;
+ int ne0;
+ int ne1;
+ int ne2;
+ uint nb0;
+ uint nb1;
+ uint nb2;
+ uint nb3;
+} pcs;
+
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
+
+ const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
+ const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
+ const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
+ const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
+
+ const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
+
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
+ out_[dst_data+i00] = OUT_TYPE(in_[src]);
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_diagmask.comp b/ggml/src/kompute-shaders/op_diagmask.comp
new file mode 100644
index 00000000..291c3fc1
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_diagmask.comp
@@ -0,0 +1,30 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+ uint n_past;
+ int ne00;
+ int ne01;
+} pcs;
+
+void main() {
+ const uint i02 = gl_WorkGroupID.z;
+ const uint i01 = gl_WorkGroupID.y;
+ const uint i00 = gl_WorkGroupID.x;
+
+ const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
+
+ if (i00 > pcs.n_past + i01) {
+ out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
+ } else {
+ out_[index + pcs.outOff] = in_[index + pcs.inOff];
+ }
+}
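
The masked positions are written as uintBitsToFloat(0xFF800000), which is the IEEE-754 bit pattern of negative infinity, so a subsequent softmax drives them to zero. A tiny host-side check of that constant, purely illustrative:

    // Confirms that 0xFF800000 is -inf when reinterpreted as a float.
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    int main() {
        const uint32_t bits = 0xFF800000u;
        float f;
        std::memcpy(&f, &bits, sizeof(f));   // bit-exact reinterpretation, no conversion
        assert(std::isinf(f) && f < 0.0f);
        return 0;
    }
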
diff --git a/ggml/src/kompute-shaders/op_gelu.comp b/ggml/src/kompute-shaders/op_gelu.comp
new file mode 100644
index 00000000..9d8c5371
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_gelu.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+} pcs;
+
+void main() {
+ const uint baseIndex = gl_WorkGroupID.x * 8;
+
+ for (uint x = 0; x < 8; x++) {
+ const uint i = baseIndex + x;
+ const float y = in_[i + pcs.inOff];
+ out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_getrows.comp b/ggml/src/kompute-shaders/op_getrows.comp
new file mode 100644
index 00000000..1a5581b2
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows.comp
@@ -0,0 +1,17 @@
+void main() {
+ const uint i = gl_WorkGroupID.x;
+ const int r = inB[i + pcs.inBOff];
+
+ int z = 0;
+ for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
+ const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
+ const mat4 result = dequantize_block(inIndex, ind%NL);
+ for (uint j = 0; j < 4; ++j) {
+ for (uint k = 0; k < 4; ++k) {
+ const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
+ out_[outIndex] = result[j][k];
+ ++z;
+ }
+ }
+ }
+}
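
This shared body implements the get-rows pattern used by every variant that includes it: for each output row i, read a source row index from inB, then dequantize that whole row of A into float output. A scalar sketch of the same pattern, with dequantize_row standing in for the per-format helpers and a contiguous float output assumed:

    // Gather-and-dequantize: out row i is the dequantized A row selected by rows[i].
    #include <cstddef>
    #include <cstdint>

    template <typename DequantizeRow>
    void get_rows(const uint8_t * A, size_t row_stride_bytes,
                  const int * rows, int n_rows, int ne00,
                  float * out, DequantizeRow dequantize_row) {
        for (int i = 0; i < n_rows; ++i) {
            const uint8_t * src = A + (size_t)rows[i] * row_stride_bytes;  // row index looked up, as inB[i] above
            dequantize_row(src, out + (size_t)i * ne00, ne00);             // one full dequantized row per output row
        }
    }
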
diff --git a/ggml/src/kompute-shaders/op_getrows_f16.comp b/ggml/src/kompute-shaders/op_getrows_f16.comp
new file mode 100644
index 00000000..48c93610
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows_f16.comp
@@ -0,0 +1,31 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb01;
+ int nb1;
+} pcs;
+
+void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
+ for (int j = 0; j < k; j++) {
+ out_[y + j] = inA[x + j];
+ }
+}
+
+void main() {
+ const uint i = gl_WorkGroupID.x;
+ const int r = inB[i + pcs.inBOff];
+
+ dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
+}
diff --git a/ggml/src/kompute-shaders/op_getrows_f32.comp b/ggml/src/kompute-shaders/op_getrows_f32.comp
new file mode 100644
index 00000000..9d7acdaf
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows_f32.comp
@@ -0,0 +1,31 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb01;
+ int nb1;
+} pcs;
+
+void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
+ for (int j = 0; j < k; j++) {
+ out_[y + j] = inA[x + j];
+ }
+}
+
+void main() {
+ const uint i = gl_WorkGroupID.x;
+ const int r = inB[i + pcs.inBOff];
+
+ dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
+}
diff --git a/ggml/src/kompute-shaders/op_getrows_q4_0.comp b/ggml/src/kompute-shaders/op_getrows_q4_0.comp
new file mode 100644
index 00000000..32b2e891
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows_q4_0.comp
@@ -0,0 +1,38 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 2
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q4_0
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb01;
+ int nb1;
+} pcs;
+
+block_q4_0 get_unaligned_block_q4_0(uint index) {
+ block_q4_0 fres;
+ fres.d = u8BufToFloat16(inA, index);
+ [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
+ fres.qs[it] = inA[index+2+it];
+ }
+ return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+ const block_q4_0 block = get_unaligned_block_q4_0(index);
+ return dequantize_q4_0(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_getrows_q4_1.comp b/ggml/src/kompute-shaders/op_getrows_q4_1.comp
new file mode 100644
index 00000000..87f2fbe1
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows_q4_1.comp
@@ -0,0 +1,39 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 2
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q4_1
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb01;
+ int nb1;
+} pcs;
+
+block_q4_1 get_unaligned_block_q4_1(uint index) {
+ block_q4_1 fres;
+ fres.d = u8BufToFloat16(inA, index);
+ fres.m = u8BufToFloat16(inA, index+2);
+ [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
+ fres.qs[it] = inA[index+4+it];
+ }
+ return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+ const block_q4_1 block = get_unaligned_block_q4_1(index);
+ return dequantize_q4_1(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_getrows_q6_k.comp b/ggml/src/kompute-shaders/op_getrows_q6_k.comp
new file mode 100644
index 00000000..9ce3545d
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_getrows_q6_k.comp
@@ -0,0 +1,44 @@
+#version 450
+
+#include "common.comp"
+
+#define NL 16
+#define BYTES_FOR_TYPE 4 /*bytes for float*/
+#define SIZE_OF_BLOCK sizeof_block_q6_k
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb01;
+ int nb1;
+} pcs;
+
+block_q6_k get_unaligned_block_q6_k(uint index) {
+ block_q6_k fres;
+ [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
+ fres.ql[it] = inA[index + it];
+ }
+ [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
+ fres.qh[it] = inA[index + QK_K/2 + it];
+ }
+ [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
+ fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
+ }
+ fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
+ return fres;
+}
+
+mat4 dequantize_block(uint index, uint il) {
+ const block_q6_k block = get_unaligned_block_q6_k(index);
+ return dequantize_q6_k(block, il);
+}
+
+#include "op_getrows.comp"
diff --git a/ggml/src/kompute-shaders/op_mul.comp b/ggml/src/kompute-shaders/op_mul.comp
new file mode 100644
index 00000000..c92647c4
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul.comp
@@ -0,0 +1,52 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1024) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int nb00;
+ int nb01;
+ int nb02;
+ int nb03;
+ int ne10;
+ int ne11;
+ int ne12;
+ int ne13;
+ int nb10;
+ int nb11;
+ int nb12;
+ int nb13;
+ int ne0;
+ int nb0;
+ int nb1;
+ int nb2;
+ int nb3;
+} pcs;
+
+void main() {
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const uint i13 = i03 % pcs.ne13;
+ const uint i12 = i02 % pcs.ne12;
+ const uint i11 = i01 % pcs.ne11;
+
+ uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
+ uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
+ uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4);
+
+ for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
+ const uint i10 = i0 % pcs.ne10;
+ out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_f16.comp b/ggml/src/kompute-shaders/op_mul_mat_f16.comp
new file mode 100644
index 00000000..8f0a9031
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_f16.comp
@@ -0,0 +1,67 @@
+#version 450
+
+#include "common.comp"
+
+#extension GL_KHR_shader_subgroup_arithmetic : require
+
+layout(local_size_x_id = 0) in;
+
+layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ int ne10;
+ int ne11;
+ int ne12;
+ uint nb10;
+ uint nb11;
+ uint nb12;
+ int ne0;
+ int ne1;
+ uint r2;
+ uint r3;
+} pcs;
+
+#define N_F16_F32 4
+
+void main() {
+ const uint r0 = gl_WorkGroupID.x;
+ const uint rb = gl_WorkGroupID.y*N_F16_F32;
+ const uint im = gl_WorkGroupID.z;
+
+ const uint i12 = im%pcs.ne12;
+ const uint i13 = im/pcs.ne12;
+
+ const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
+
+ const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
+
+ for (uint row = 0; row < N_F16_F32; ++row) {
+ uint r1 = rb + row;
+ if (r1 >= pcs.ne11) {
+ break;
+ }
+
+ const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
+
+ float sumf = 0;
+ for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
+ sumf += float(inA[x+i]) * float(inB[y+i]);
+ }
+
+ const float all_sum = subgroupAdd(sumf);
+ if (subgroupElect()) {
+ out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
+ }
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp b/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp
new file mode 100644
index 00000000..d1ca4ad6
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp
@@ -0,0 +1,51 @@
+#version 450
+
+#include "common.comp"
+
+#extension GL_KHR_shader_subgroup_arithmetic : require
+#extension GL_EXT_debug_printf : enable
+
+// device subgroup size
+layout (local_size_x_id = 0) in;
+
+layout(binding = 0) readonly buffer tensorInA { float inA[]; };
+layout(binding = 1) readonly buffer tensorInB { float inB[]; };
+layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout(push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ int ne11;
+ int ne12;
+ uint nb01;
+ uint nb02;
+ uint nb11;
+ uint nb12;
+ uint nb1;
+ uint nb2;
+}
+pcs;
+
+
+void main() {
+ uvec3 gid = gl_WorkGroupID;
+
+ uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
+ uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
+
+ const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
+ const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
+ float sum = 0.0f;
+ for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
+ sum += float(inA[x+i]) * float(inB[y+i]);
+ }
+
+ const float all_sum = subgroupAdd(sum);
+ if (subgroupElect()) {
+ out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp b/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp
new file mode 100644
index 00000000..b0cea8bb
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp
@@ -0,0 +1,33 @@
+#version 450
+
+#include "common.comp"
+
+#define BLOCKS_IN_QUANT QK4_0
+#define SIZE_OF_BLOCK sizeof_block_q4_0
+#define N_ROWS 4
+
+#include "op_mul_mv_q_n_pre.comp"
+
+// The q4_0 version of this function
+float block_q_n_dot_y(uint block_index, uint yb, uint il) {
+ vec2 acc = vec2(0.0, 0.0);
+ const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
+ float d = float(u8BufToFloat16(inA, index));
+ float sumy = 0.0f;
+ for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
+ const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
+
+ const float yl0 = inB[yb + i];
+ const float yl1 = inB[yb + i + 1];
+ const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
+ const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
+
+ sumy += yl0 + yl1 + yl8 + yl9;
+
+ acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
+ acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
+ }
+ return d * (sumy * -8.f + acc[0] + acc[1]);
+}
+
+#include "op_mul_mv_q_n.comp"
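
block_q_n_dot_y folds the constant -8 offset of q4_0 out of the inner loop using sum_i d*(q_i - 8)*y_i = d*(sum_i q_i*y_i - 8*sum_i y_i), which is why it only needs the raw nibble products plus the running sumy; op_mul_mat_q4_1.comp applies the analogous identity d*sum_i q_i*y_i + m*sum_i y_i for its scale/offset pair. A scalar check of the q4_0 rearrangement, with an illustrative harness:

    // Verifies the -8-offset factoring used by the q4_0 dot product above.
    #include <cassert>
    #include <cmath>
    #include <cstdlib>

    int main() {
        const float d = 0.0625f;
        float q[32], y[32];
        for (int i = 0; i < 32; ++i) {
            q[i] = (float)(rand() % 16);              // raw 4-bit quants, 0..15
            y[i] = (float)(rand() % 200 - 100)/50.f;  // arbitrary activations
        }
        float ref = 0.f, qy = 0.f, sumy = 0.f;
        for (int i = 0; i < 32; ++i) {
            ref  += d*(q[i] - 8.f)*y[i];              // straightforward per-weight form
            qy   += q[i]*y[i];
            sumy += y[i];
        }
        const float fast = d*(qy - 8.f*sumy);         // form used by the shader (sumy * -8.f + acc)
        assert(fabsf(ref - fast) < 1e-4f);
        return 0;
    }
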
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp b/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp
new file mode 100644
index 00000000..8582c61a
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "common.comp"
+
+#define BLOCKS_IN_QUANT QK4_1
+#define SIZE_OF_BLOCK sizeof_block_q4_1
+#define N_ROWS 4
+
+#include "op_mul_mv_q_n_pre.comp"
+
+// The q4_1 version of this function
+float block_q_n_dot_y(uint block_index, uint yb, uint il) {
+ vec2 acc = vec2(0.0, 0.0);
+ const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
+ float d = float(u8BufToFloat16(inA, index));
+ float m = float(u8BufToFloat16(inA, index+2));
+
+ float sumy = 0.0f;
+ for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
+ const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
+
+ const float yl0 = inB[yb + i];
+ const float yl1 = inB[yb + i + 1];
+ const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
+ const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
+
+ sumy += yl0 + yl1 + yl8 + yl9;
+
+ acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
+ acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
+#include "op_mul_mv_q_n.comp"
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp b/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp
new file mode 100644
index 00000000..c9baebdf
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp
@@ -0,0 +1,94 @@
+#version 450
+
+#include "common.comp"
+
+#define SIZE_OF_BLOCK sizeof_block_q6_k
+
+layout(local_size_x_id = 0) in;
+layout(local_size_y_id = 1) in;
+layout(local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne10;
+ int ne0;
+ int ne1;
+ int ne01;
+ int gqa;
+} pcs;
+
+void main() {
+ const uint8_t kmask1 = uint8_t(0x03);
+ const uint8_t kmask2 = uint8_t(0x0C);
+ const uint8_t kmask3 = uint8_t(0x30);
+ const uint8_t kmask4 = uint8_t(0xC0);
+
+ const uint nb = pcs.ne00/QK_K;
+
+ const uint r0 = gl_WorkGroupID.x;
+ const uint r1 = gl_WorkGroupID.y;
+ const uint r2 = gl_WorkGroupID.z;
+
+ const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
+ const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
+ const uint x = row * nb + offset0; // Based from inA without base offset
+ const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+
+ float sumf = 0;
+
+ // bits of invocation ID for gl_SubgroupSize=32:
+ // x x x x x
+ // 4 3 2 1 0
+ // ( tid ) ix
+ // ip ( il )
+
+ const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
+ const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
+ const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
+ const uint ip = tid/8; // first or second half of block (0 or 1)
+ const uint il = tid%8; // each half has 8 parts, one per scale
+ const uint n = 4; // 4 scales at a time (and 4 sums)
+ const uint l0 = n*il; // offset into half-block, 0..28
+ const uint is = 8*ip + l0/16; // 0, 1, 8, 9
+
+ const uint y_offset = 128*ip + l0;
+ const uint q_offset_l = 64*ip + l0;
+ const uint q_offset_h = 32*ip + l0;
+
+ for (uint i = ix; i < nb; i += block_stride) {
+
+ const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
+
+ const uint qlIndex = q_offset_l;
+ const uint q2Index = qlIndex + QK_K/8;
+ const uint qhIndex = q_offset_h;
+ const uint y = yy + i * QK_K + y_offset;
+
+ float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+ for (uint l = 0; l < n; ++l) {
+ const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
+ const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
+ const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
+
+ sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
+ sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
+ sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
+ sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
+ }
+
+ float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
+ sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
+ }
+
+ const float tot = subgroupAdd(sumf);
+ if (subgroupElect()) {
+ out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp b/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp
new file mode 100644
index 00000000..34d015e9
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp
@@ -0,0 +1,73 @@
+#version 450
+
+#include "common.comp"
+
+#include "op_mul_mv_q_n_pre.comp"
+
+#define SIZE_OF_D 2
+
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+#define NB_Q8_0 8
+
+void main() {
+ // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
+ if (gl_SubgroupInvocationID > 31)
+ return;
+
+ const int nr = N_DST;
+ const int nsg = N_SIMDGROUP;
+ const int nw = N_SIMDWIDTH;
+
+ const int nb = pcs.ne00/QK8_0;
+ const uint r0 = gl_WorkGroupID.x;
+ const uint r1 = gl_WorkGroupID.y;
+ const uint im = gl_WorkGroupID.z;
+
+ const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
+
+ const uint i12 = im%pcs.ne12;
+ const uint i13 = im/pcs.ne12;
+
+ const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+
+ const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
+ const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
+
+ float yl[NB_Q8_0];
+ float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
+
+ const uint ix = gl_SubgroupInvocationID.x/4;
+ const uint il = gl_SubgroupInvocationID.x%4;
+
+ uint yb = y + ix * QK8_0 + NB_Q8_0*il;
+
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (uint ib = ix; ib < nb; ib += nw/4) {
+ for (int i = 0; i < NB_Q8_0; ++i) {
+ yl[i] = inB[yb + i];
+ }
+
+ for (int row = 0; row < nr; row++) {
+ const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
+ float sumq = 0.f;
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
+ const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
+ sumq += qs_iq * yl[iq];
+ }
+ const float16_t d = u8BufToFloat16(inA, x + block_offset);
+ sumf[row] += sumq*d;
+ }
+
+ yb += NB_Q8_0 * nw;
+ }
+
+ for (int row = 0; row < nr; ++row) {
+ const float tot = subgroupAdd(sumf[row]);
+ if (subgroupElect() && first_row + row < pcs.ne01) {
+ out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
+ }
+ }
+}
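
The inner loop above accumulates sumq += qs_iq * yl[iq] per block and multiplies by the block's fp16 scale once, so each q8_0 block (a scale d followed by 32 int8 quants, as in the commented block_q8_0 struct in common.comp) contributes d * sum_i q_i*y_i to the row result. A scalar reference of that per-block contribution, helper name ours:

    // One q8_0 block's contribution to a dot product: d times the int8/float inner product.
    #include <cstdint>

    static float q8_0_block_dot(float d, const int8_t qs[32], const float y[32]) {
        float sumq = 0.f;
        for (int i = 0; i < 32; ++i) sumq += (float)qs[i] * y[i];
        return d * sumq;
    }
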
diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/kompute-shaders/op_mul_mv_q_n.comp
new file mode 100644
index 00000000..440b5ab2
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mv_q_n.comp
@@ -0,0 +1,48 @@
+void main() {
+ // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
+ if (gl_SubgroupInvocationID > 31)
+ return;
+
+ const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
+
+ const uint r0 = gl_WorkGroupID.x;
+ const uint r1 = gl_WorkGroupID.y;
+ const uint im = gl_WorkGroupID.z;
+
+ const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
+
+ const uint i12 = im%pcs.ne12;
+ const uint i13 = im/pcs.ne12;
+
+ const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
+
+ const uint x = offset0; // Based from inA without base offset
+ const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
+
+ float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
+
+ const uint ix = gl_SubgroupInvocationID/2;
+ const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
+
+ uint yb = y + ix * BLOCKS_IN_QUANT + il;
+
+ //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
+ // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
+ // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
+
+ for (uint ib = ix; ib < nb; ib += 16) {
+ for (int row = 0; row < N_ROWS; row++) {
+ const uint block_index = x + ib + row * nb;
+ sumf[row] += block_q_n_dot_y(block_index, yb, il);
+ }
+
+ yb += BLOCKS_IN_QUANT * 16;
+ }
+
+ for (int row = 0; row < N_ROWS; ++row) {
+ const float tot = subgroupAdd(sumf[row]);
+ if (first_row + row < pcs.ne01 && subgroupElect()) {
+ out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
+ }
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp
new file mode 100644
index 00000000..7912b09a
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp
@@ -0,0 +1,22 @@
+layout(local_size_x_id = 0) in;
+layout(local_size_y = 1) in;
+layout(local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
+layout (binding = 1) readonly buffer tensorInB { float inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ int ne10;
+ int ne12;
+ int ne0;
+ int ne1;
+ uint r2;
+ uint r3;
+} pcs;
diff --git a/ggml/src/kompute-shaders/op_norm.comp b/ggml/src/kompute-shaders/op_norm.comp
new file mode 100644
index 00000000..ad0c3c01
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_norm.comp
@@ -0,0 +1,84 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 256) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+ uint ne00;
+ uint nb01;
+ float eps;
+} pcs;
+
+shared float sum[gl_WorkGroupSize.x];
+
+void main() {
+ const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
+ // MEAN
+ // parallel sum
+ sum[gl_LocalInvocationID.x] = 0.0;
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ sum[gl_LocalInvocationID.x] += in_[x+i00];
+ }
+
+ // reduce
+ barrier();
+ memoryBarrierShared();
+ [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+ if (gl_LocalInvocationID.x < i) {
+ sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+ }
+ barrier();
+ memoryBarrierShared();
+ }
+
+ // broadcast
+ if (gl_LocalInvocationID.x == 0) {
+ sum[0] /= float(pcs.ne00);
+ }
+ barrier();
+ memoryBarrierShared();
+ const float mean = sum[0];
+
+ // recenter
+ const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ out_[y+i00] = in_[x+i00] - mean;
+ }
+
+ // VARIANCE
+ // parallel sum
+ sum[gl_LocalInvocationID.x] = 0.0;
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
+ }
+
+ // reduce
+ barrier();
+ memoryBarrierShared();
+ [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+ if (gl_LocalInvocationID.x < i) {
+ sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+ }
+ barrier();
+ memoryBarrierShared();
+ }
+
+ // broadcast
+ if (gl_LocalInvocationID.x == 0) {
+ sum[0] /= float(pcs.ne00);
+ }
+ barrier();
+ memoryBarrierShared();
+ const float variance = sum[0];
+
+ const float scale = 1.0f/sqrt(variance + pcs.eps);
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ out_[y+i00] *= scale;
+ }
+}
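
For reference, the per-row normalization that op_norm.comp performs reduces to a few lines of scalar code. The sketch below mirrors the shader's two reductions (mean, then variance of the recentered values) and the final 1/sqrt(variance + eps) scaling; the function name and the use of double accumulators are illustrative only and not part of the patch.

// Scalar reference for op_norm.comp (illustrative names, not part of the patch).
#include <cmath>
#include <cstddef>

void layer_norm_row(const float * in, float * out, std::size_t ne00, float eps) {
    // MEAN: parallel sum in the shader, plain loop here
    double sum = 0.0;
    for (std::size_t i = 0; i < ne00; ++i) sum += in[i];
    const float mean = float(sum / double(ne00));

    // recenter, then VARIANCE of the recentered values
    double var = 0.0;
    for (std::size_t i = 0; i < ne00; ++i) {
        out[i] = in[i] - mean;
        var   += double(out[i]) * double(out[i]);
    }
    const float variance = float(var / double(ne00));

    // scale by 1/sqrt(variance + eps), exactly as the shader does
    const float scale = 1.0f / std::sqrt(variance + eps);
    for (std::size_t i = 0; i < ne00; ++i) out[i] *= scale;
}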
diff --git a/ggml/src/kompute-shaders/op_relu.comp b/ggml/src/kompute-shaders/op_relu.comp
new file mode 100644
index 00000000..52a601fe
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_relu.comp
@@ -0,0 +1,21 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+} pcs;
+
+void main() {
+ const uint baseIndex = gl_WorkGroupID.x * 4;
+
+ for (uint x = 0; x < 4; x++) {
+ const uint i = baseIndex + x;
+ out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_rmsnorm.comp b/ggml/src/kompute-shaders/op_rmsnorm.comp
new file mode 100644
index 00000000..da658c16
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_rmsnorm.comp
@@ -0,0 +1,53 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 512) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+ uint ne00;
+ uint nb01;
+ float eps;
+} pcs;
+
+shared float sum[gl_WorkGroupSize.x];
+
+void main() {
+ const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
+
+ // parallel sum
+ sum[gl_LocalInvocationID.x] = 0.0;
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
+ }
+
+ // reduce
+ barrier();
+ memoryBarrierShared();
+ [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
+ if (gl_LocalInvocationID.x < i) {
+ sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
+ }
+ barrier();
+ memoryBarrierShared();
+ }
+
+ // broadcast
+ if (gl_LocalInvocationID.x == 0) {
+ sum[0] /= float(pcs.ne00);
+ }
+ barrier();
+ memoryBarrierShared();
+
+ const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
+
+ const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
+ for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
+ out_[y+i00] = in_[x+i00] * scale;
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_rope_f16.comp b/ggml/src/kompute-shaders/op_rope_f16.comp
new file mode 100644
index 00000000..1a4058b3
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_rope_f16.comp
@@ -0,0 +1,73 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
+
+void main() {
+ const uint i3 = gl_WorkGroupID.z;
+ const uint i2 = gl_WorkGroupID.y;
+ const uint i1 = gl_WorkGroupID.x;
+
+ const bool is_neox = (pcs.mode & 2) != 0;
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+ const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+ const int p = inB[pcs.inBOff + i2];
+
+ float theta = float(p);
+
+ if (!is_neox) {
+ for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
+ float cos_theta, sin_theta;
+ rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+ theta *= theta_scale;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
+
+ const float x0 = float(inA[src]);
+ const float x1 = float(inA[src+1]);
+
+ out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
+ out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
+ }
+ } else {
+ const float inv_ndims = -1.f/pcs.n_dims;
+ for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
+ const uint cur_rot = ic;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+ theta *= theta_scale;
+
+ const uint i0 = ic/2;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
+
+ const float x0 = float(inA[src]);
+ const float x1 = float(inA[src+pcs.n_dims/2]);
+
+ out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
+ out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
+ }
+
+ for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
+ const uint i0 = ic;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
+
+ out_[dst_data + 0] = inA[src + 0];
+ out_[dst_data + 1] = inA[src + 1];
+ }
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_rope_f32.comp b/ggml/src/kompute-shaders/op_rope_f32.comp
new file mode 100644
index 00000000..65e03827
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_rope_f32.comp
@@ -0,0 +1,73 @@
+#version 450
+
+#include "rope_common.comp"
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+void main() {
+ const uint i3 = gl_WorkGroupID.z;
+ const uint i2 = gl_WorkGroupID.y;
+ const uint i1 = gl_WorkGroupID.x;
+
+ const bool is_neox = (pcs.mode & 2) != 0;
+
+ float corr_dims[2];
+ rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
+
+ const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
+
+ const int p = inB[pcs.inBOff + i2];
+
+ float theta = float(p);
+
+ if (!is_neox) {
+ for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
+ float cos_theta, sin_theta;
+ rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+ theta *= theta_scale;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
+
+ const float x0 = inA[src];
+ const float x1 = inA[src+1];
+
+ out_[dst_data] = x0*cos_theta - x1*sin_theta;
+ out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
+ }
+ } else {
+ const float inv_ndims = -1.f/pcs.n_dims;
+ for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
+ const uint cur_rot = ic;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
+
+ theta *= theta_scale;
+
+ const uint i0 = ic/2;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
+
+ const float x0 = inA[src];
+ const float x1 = inA[src+pcs.n_dims/2];
+
+ out_[dst_data] = x0*cos_theta - x1*sin_theta;
+ out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
+ }
+
+ for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
+ const uint i0 = ic;
+
+ const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
+ const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
+
+ out_[dst_data + 0] = inA[src + 0];
+ out_[dst_data + 1] = inA[src + 1];
+ }
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_scale.comp b/ggml/src/kompute-shaders/op_scale.comp
new file mode 100644
index 00000000..bdae2673
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_scale.comp
@@ -0,0 +1,19 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+ float scale;
+} pcs;
+
+void main() {
+ const uint i = gl_WorkGroupID.x;
+ out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+}
diff --git a/ggml/src/kompute-shaders/op_scale_8.comp b/ggml/src/kompute-shaders/op_scale_8.comp
new file mode 100644
index 00000000..ada69754
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_scale_8.comp
@@ -0,0 +1,23 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+ float scale;
+} pcs;
+
+void main() {
+ const uint baseIndex = gl_WorkGroupID.x * 8;
+
+ for (uint x = 0; x < 8; x++) {
+ const uint i = baseIndex + x;
+ out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_silu.comp b/ggml/src/kompute-shaders/op_silu.comp
new file mode 100644
index 00000000..0fb8e4b7
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_silu.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
+layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
+layout(push_constant) uniform PushConstants {
+ uint inOff;
+ uint outOff;
+} pcs;
+
+void main() {
+ const uint baseIndex = gl_WorkGroupID.x * 4;
+
+ for (uint x = 0; x < 4; x++) {
+ const uint i = baseIndex + x;
+ const float y = in_[i + pcs.inOff];
+ out_[i + pcs.outOff] = y / (1.0 + exp(-y));
+ }
+}
diff --git a/ggml/src/kompute-shaders/op_softmax.comp b/ggml/src/kompute-shaders/op_softmax.comp
new file mode 100644
index 00000000..7bc9176c
--- /dev/null
+++ b/ggml/src/kompute-shaders/op_softmax.comp
@@ -0,0 +1,56 @@
+// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
+
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x_id = 0) in;
+
+layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
+layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
+layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
+
+layout(push_constant) uniform PushConstants {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int ne00;
+ int ne01;
+ int ne02;
+ float scale;
+ int mask;
+} pcs;
+
+void main() {
+ if (gl_SubgroupInvocationID > 31)
+ return;
+
+ const uint i03 = gl_WorkGroupID.z;
+ const uint i02 = gl_WorkGroupID.y;
+ const uint i01 = gl_WorkGroupID.x;
+
+ const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
+ const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
+ const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
+ const uint pdst = extra_off + pcs.outOff; // Based from out_
+
+ // parallel max
+ float localMax = uintBitsToFloat(0xFF800000);
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
+ }
+ float max_ = subgroupMax(localMax);
+
+ // parallel sum
+ float localSum = 0.0f;
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
+ localSum += exp_psrc0;
+ out_[pdst + i00] = exp_psrc0;
+ }
+
+ const float sum = subgroupAdd(localSum);
+ for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
+ out_[pdst + i00] /= sum;
+ }
+}
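
The softmax shader applies the scale and the optional additive mask before the max/exp passes. A scalar sketch of the same computation, with hypothetical names and without the subgroup parallelism, looks roughly like this:

// Scalar reference for op_softmax.comp (illustrative names, no subgroup parallelism).
#include <algorithm>
#include <cmath>
#include <cstddef>

void softmax_row(const float * src, const float * mask, float * dst,
                 std::size_t ne00, float scale) {
    // parallel max in the shader, plain scan here
    float max_ = -INFINITY;
    for (std::size_t i = 0; i < ne00; ++i) {
        max_ = std::max(max_, src[i] * scale + (mask ? mask[i] : 0.0f));
    }
    // exponentiate and accumulate the sum
    float sum = 0.0f;
    for (std::size_t i = 0; i < ne00; ++i) {
        const float e = std::exp(src[i] * scale + (mask ? mask[i] : 0.0f) - max_);
        dst[i] = e;
        sum   += e;
    }
    // normalize
    for (std::size_t i = 0; i < ne00; ++i) {
        dst[i] /= sum;
    }
}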
diff --git a/ggml/src/kompute-shaders/rope_common.comp b/ggml/src/kompute-shaders/rope_common.comp
new file mode 100644
index 00000000..7b9394cb
--- /dev/null
+++ b/ggml/src/kompute-shaders/rope_common.comp
@@ -0,0 +1,67 @@
+#include "common.comp"
+
+// TODO: use a local size of 32 or more (Metal uses 1024)
+layout(local_size_x = 1) in;
+
+layout (push_constant) uniform parameter {
+ uint inAOff;
+ uint inBOff;
+ uint outOff;
+ int n_dims;
+ int mode;
+ int n_ctx_orig;
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+ uint nb00;
+ uint nb01;
+ uint nb02;
+ uint nb03;
+ int ne0;
+ uint nb0;
+ uint nb1;
+ uint nb2;
+ uint nb3;
+} pcs;
+
+float rope_yarn_ramp(const float low, const float high, const float i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
+ out float cos_theta, out float sin_theta
+) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ cos_theta = cos(theta) * mscale;
+ sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+ return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
+}
+
+void rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
+) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
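
The YaRN helpers above translate almost one-to-one to host code, which is handy for cross-checking the shader against the CPU path. The transcription below assumes TWOPI_F equals 2*pi as defined in common.comp (not shown in this diff); everything else is a direct port of the GLSL.

// Host-side transcription of the YaRN helpers (direct port of the GLSL above).
#include <algorithm>
#include <cmath>

static const float TWOPI_F = 6.283185307179586f; // assumption: same value as common.comp

float rope_yarn_ramp(float low, float high, float i0) {
    const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}

void rope_yarn(float theta_extrap, float freq_scale, const float corr_dims[2], float i0,
               float ext_factor, float mscale, float & cos_theta, float & sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    cos_theta = std::cos(theta) * mscale;
    sin_theta = std::sin(theta) * mscale;
}

float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    return n_dims * std::log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * std::log(base));
}

void rope_yarn_corr_dims(int n_dims, int n_ctx_orig, float freq_base,
                         float beta_fast, float beta_slow, float dims[2]) {
    // start and end correction dims
    dims[0] = std::max(0.0f, std::floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
    dims[1] = std::min(n_dims - 1.0f, std::ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}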
diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp
new file mode 100644
index 00000000..9d56af78
--- /dev/null
+++ b/ggml/src/llamafile/sgemm.cpp
@@ -0,0 +1,1028 @@
+// Copyright 2024 Mozilla Foundation
+//
+// Permission is hereby granted, free of charge, to any person obtaining
+// a copy of this software and associated documentation files (the
+// "Software"), to deal in the Software without restriction, including
+// without limitation the rights to use, copy, modify, merge, publish,
+// distribute, sublicense, and/or sell copies of the Software, and to
+// permit persons to whom the Software is furnished to do so, subject to
+// the following conditions:
+//
+// The above copyright notice and this permission notice shall be
+// included in all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+//
+// _ _ ___ _ _ ___
+// | |_(_)_ _ _ _| _ ) | /_\ / __|
+// | _| | ' \ || | _ \ |__ / _ \\__ \.
+// \__|_|_||_\_, |___/____/_/ \_\___/
+// |__/
+//
+// BASIC LINEAR ALGEBRA SUBPROGRAMS
+//
+//
+// This file implements multithreaded CPU matrix multiplication for the
+// common contiguous use case C = Aᵀ * B. These kernels are designed to
+// have excellent performance[1] for matrices that fit in the CPU cache
+// without imposing any overhead such as cache filling or malloc calls.
+//
+// This implementation does not guarantee any upper bound on rounding
+// error, which grows along with k. Our goal is to maximally exploit the
+// hardware for performance, and then use whatever resources remain to
+// improve numerical accuracy.
+//
+// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
+// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
+
+#include "sgemm.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+
+#ifdef _MSC_VER
+#define NOINLINE __declspec(noinline)
+#else
+#define NOINLINE __attribute__((__noinline__))
+#endif
+
+#if defined(__ARM_NEON) || defined(__AVX512F__)
+#define VECTOR_REGISTERS 32
+#else
+#define VECTOR_REGISTERS 16
+#endif
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+namespace {
+
+inline float unhalf(ggml_fp16_t d) {
+ return GGML_FP16_TO_FP32(d);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED ARITHMETIC OPERATIONS
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
+inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
+inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
+#endif // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
+inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
+inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
+inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
+inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
+#endif // __AVX512F__
+
+#if defined(__ARM_NEON)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
+inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
+inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+ return add(mul(a, b), c);
+}
+
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+ return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+ return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+ return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+ return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED HORIZONTAL SUM
+
+#if defined(__ARM_NEON)
+inline float hsum(float32x4_t x) {
+ return vaddvq_f32(x);
+}
+#endif // __ARM_NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+inline float hsum(float16x8_t x) {
+ return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
+ vcvt_f32_f16(vget_high_f16(x))));
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m128 x) {
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
+#else
+ __m128 t;
+ t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
+ x = _mm_add_ps(x, t);
+ t = _mm_movehl_ps(t, x);
+ x = _mm_add_ss(x, t);
+#endif
+ return _mm_cvtss_f32(x);
+}
+#endif
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+inline float hsum(__m256 x) {
+ return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
+ _mm256_castps256_ps128(x)));
+}
+#endif // __AVX__
+
+#if defined(__AVX512F__)
+inline float hsum(__m512 x) {
+ return _mm512_reduce_add_ps(x);
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED MEMORY LOADING
+
+template <typename T, typename U> T load(const U *);
+
+#if defined(__ARM_NEON)
+template <> inline float32x4_t load(const float *p) {
+ return vld1q_f32(p);
+}
+#if !defined(_MSC_VER)
+template <> inline float16x8_t load(const ggml_fp16_t *p) {
+ return vld1q_f16((const float16_t *)p);
+}
+template <> inline float32x4_t load(const ggml_fp16_t *p) {
+ return vcvt_f32_f16(vld1_f16((const float16_t *)p));
+}
+#endif // _MSC_VER
+#endif // __ARM_NEON
+
+#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m128 load(const float *p) {
+ return _mm_loadu_ps(p);
+}
+#endif // __SSE__
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <> inline __m256 load(const float *p) {
+ return _mm256_loadu_ps(p);
+}
+#endif // __AVX__
+
+#if defined(__F16C__)
+template <> inline __m256 load(const ggml_fp16_t *p) {
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
+}
+#endif // __F16C__
+
+#if defined(__AVX512F__)
+template <> inline __m512 load(const float *p) {
+ return _mm512_loadu_ps(p);
+}
+template <> inline __m512 load(const ggml_fp16_t *p) {
+ return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
+}
+#endif // __AVX512F__
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// FLOATING POINT MATRIX MULTIPLICATION
+
+template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
+class tinyBLAS {
+ public:
+ tinyBLAS(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
+#if VECTOR_REGISTERS == 32
+ case 0x55:
+ mc = 5;
+ nc = 5;
+ gemm<5, 5>(m0, m, n0, n);
+ break;
+ case 0x45:
+ mc = 4;
+ nc = 5;
+ gemm<4, 5>(m0, m, n0, n);
+ break;
+ case 0x54:
+ mc = 5;
+ nc = 4;
+ gemm<5, 4>(m0, m, n0, n);
+ break;
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x53:
+ mc = 5;
+ nc = 3;
+ gemm<5, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+ mc = 3;
+ nc = 5;
+ gemm<3, 5>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+#else
+ case 0x55:
+ case 0x54:
+ case 0x53:
+ case 0x45:
+ case 0x44:
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x35:
+#endif
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x52:
+ mc = 5;
+ nc = 2;
+ gemm<5, 2>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x25:
+ mc = 2;
+ nc = 5;
+ gemm<2, 5>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x51:
+ mc = 5;
+ nc = 1;
+ gemm<5, 1>(m0, m, n0, n);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x15:
+ mc = 1;
+ nc = 5;
+ gemm<1, 5>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ D Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; l += KN)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
+ load<V>(B + ldb * (jj + j) + l),
+ Cv[j][i]);
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// QUANT ZERO MATRIX MULTIPLICATION
+
+#if defined(__ARM_FEATURE_DOTPROD)
+template <typename TA>
+class tinyBLAS_Q0_ARM {
+ public:
+ tinyBLAS_Q0_ARM(int64_t k,
+ const TA *A, int64_t lda,
+ const block_q8_0 *B, int64_t ldb,
+ float *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ float32x4_t Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
+ vcvtq_f32_s32(vdotq_s32(
+ vdotq_s32(vdupq_n_s32(0),
+ load_lo(A + lda * (ii + i) + l),
+ load_lo(B + ldb * (jj + j) + l)),
+ load_hi(A + lda * (ii + i) + l),
+ load_hi(B + ldb * (jj + j) + l))),
+ unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d));
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ inline int8x16_t load_lo(const block_q8_0 *b) {
+ return vld1q_s8(b->qs);
+ }
+
+ inline int8x16_t load_hi(const block_q8_0 *b) {
+ return vld1q_s8(b->qs + 16);
+ }
+
+ inline int8x16_t load_lo(const block_q4_0 *b) {
+ return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
+ vdupq_n_u8(0x0f))),
+ vdupq_n_s8(0x8));
+ }
+
+ inline int8x16_t load_hi(const block_q4_0 *b) {
+ return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
+ vdupq_n_s8(0x8));
+ }
+
+ const TA *const A;
+ const block_q8_0 *const B;
+ float *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+#endif // __ARM_FEATURE_DOTPROD
+
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_AVX {
+ public:
+ tinyBLAS_Q0_AVX(int64_t k,
+ const TA *A, int64_t lda,
+ const TB *B, int64_t ldb,
+ TC *C, int64_t ldc,
+ int ith, int nth)
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+ }
+
+ void matmul(int64_t m, int64_t n) {
+ mnpack(0, m, 0, n);
+ }
+
+ private:
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t mc, nc, mp, np;
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
+#if VECTOR_REGISTERS == 32
+ case 0x44:
+ mc = 4;
+ nc = 4;
+ gemm<4, 4>(m0, m, n0, n);
+ break;
+ case 0x43:
+ mc = 4;
+ nc = 3;
+ gemm<4, 3>(m0, m, n0, n);
+ break;
+ case 0x34:
+ mc = 3;
+ nc = 4;
+ gemm<3, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+ mc = 3;
+ nc = 3;
+ gemm<3, 3>(m0, m, n0, n);
+ break;
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+#else
+ case 0x44:
+ case 0x43:
+ case 0x42:
+ mc = 4;
+ nc = 2;
+ gemm<4, 2>(m0, m, n0, n);
+ break;
+ case 0x34:
+ case 0x24:
+ mc = 2;
+ nc = 4;
+ gemm<2, 4>(m0, m, n0, n);
+ break;
+ case 0x33:
+#endif
+ case 0x32:
+ mc = 3;
+ nc = 2;
+ gemm<3, 2>(m0, m, n0, n);
+ break;
+ case 0x23:
+ mc = 2;
+ nc = 3;
+ gemm<2, 3>(m0, m, n0, n);
+ break;
+ case 0x41:
+ mc = 4;
+ nc = 1;
+ gemm<4, 1>(m0, m, n0, n);
+ break;
+ case 0x22:
+ mc = 2;
+ nc = 2;
+ gemm<2, 2>(m0, m, n0, n);
+ break;
+ case 0x14:
+ mc = 1;
+ nc = 4;
+ gemm<1, 4>(m0, m, n0, n);
+ break;
+ case 0x31:
+ mc = 3;
+ nc = 1;
+ gemm<3, 1>(m0, m, n0, n);
+ break;
+ case 0x13:
+ mc = 1;
+ nc = 3;
+ gemm<1, 3>(m0, m, n0, n);
+ break;
+ case 0x21:
+ mc = 2;
+ nc = 1;
+ gemm<2, 1>(m0, m, n0, n);
+ break;
+ case 0x12:
+ mc = 1;
+ nc = 2;
+ gemm<1, 2>(m0, m, n0, n);
+ break;
+ case 0x11:
+ mc = 1;
+ nc = 1;
+ gemm<1, 1>(m0, m, n0, n);
+ break;
+ default:
+ return;
+ }
+ mp = m0 + (m - m0) / mc * mc;
+ np = n0 + (n - n0) / nc * nc;
+ mnpack(mp, m, n0, np);
+ mnpack(m0, m, np, n);
+ }
+
+ template <int RM, int RN>
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * RN;
+ __m256 Cv[RN][RM] = {};
+ for (int64_t l = 0; l < k; ++l)
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i) {
+#if defined(__AVX2__)
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
+ load(A + lda * (ii + i) + l)));
+#else
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
+
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
+
+ // updot
+ const __m128i oneFill = _mm_set1_epi16(1);
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
+#endif
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
+ unhalf(B[ldb * (jj + j) + l].d)),
+ udTmp,
+ Cv[j][i]);
+ }
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ inline __m256i load(const block_q8_0 *b) {
+ return _mm256_loadu_si256((const __m256i *)b->qs);
+ }
+
+ inline __m128i load0(const block_q8_0 *b) {
+ return _mm_loadu_si128((const __m128i *)b->qs);
+ }
+
+ inline __m128i load1(const block_q8_0 *b) {
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
+ }
+
+ inline __m256i load(const block_q4_0 *b) {
+ return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
+ }
+
+ inline __m128i load0(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
+ }
+
+ inline __m128i load1(const block_q4_0 *b) {
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
+ }
+
+ inline __m256 updot(__m256i u, __m256i s) {
+ __m256i res;
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
+#else
+ res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
+#endif
+ return _mm256_cvtepi32_ps(res);
+ }
+
+ static inline __m256i denibble(const uint8_t *p) {
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
+ return _mm256_and_si256(_mm256_set1_epi8(15),
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
+ _mm_srli_epi16(x, 4), 1));
+ }
+
+ const TA *const A;
+ const TB *const B;
+ TC *const C;
+ const int64_t k;
+ const int64_t lda;
+ const int64_t ldb;
+ const int64_t ldc;
+ const int ith;
+ const int nth;
+};
+#endif // __AVX__
+
+} // namespace
+
+/**
+ * Performs optimized matrix multiplication on CPU.
+ *
+ * This subroutine may compute C = Aᵀ * B with column major ordering.
+ * Despite its name, this isn't a generalized implementation: work is
+ * performed only when a handwritten kernel exists for the given types;
+ * otherwise the caller should fall back to a general matmul routine.
+ *
+ * For example, for single-threaded single-precision GEMM you can say
+ *
+ * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+ * 0, 1,
+ * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+ *
+ * @param m is rows in `A` and `C`
+ * @param n is cols in `B` and `C`
+ * @param k is cols in `A` and rows in `B`
+ * @param A is first input matrix (always transposed)
+ * @param lda is row stride of `A`
+ * @param B is second input matrix (never transposed)
+ * @param ldb is row stride of `B`
+ * @param C is input/output array of output matrices
+ * @param ldc is row stride of `C`
+ * @param ith is thread id (must be less than `nth`)
+ * @param nth is number of threads (must be greater than zero)
+ * @param Atype is GGML data type of `A`
+ * @param Btype is GGML data type of `B`
+ * @param Ctype is GGML data type of `C`
+ * @return true if this function was able to service the matmul request
+ */
+
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+ int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+
+ assert(m >= 0);
+ assert(n >= 0);
+ assert(k >= 0);
+ assert(lda >= k);
+ assert(ldb >= k);
+ assert(ldc >= m);
+ assert(nth > 0);
+ assert(ith < nth);
+
+ if (Ctype != GGML_TYPE_F32)
+ return false;
+
+ switch (Atype) {
+
+ case GGML_TYPE_F32: {
+ if (Btype != GGML_TYPE_F32)
+ return false;
+#if defined(__AVX512F__)
+ if (k % 16)
+ return false;
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__AVX__) || defined(__AVX2__)
+ if (k % 8)
+ return false;
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_NEON)
+ if (n < 4)
+ return false;
+ if (k % 4)
+ return false;
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+ k, (const float *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_F16: {
+#if defined(__AVX512F__)
+ if (k % 16)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+ if (k % 8)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+ if (n < 8)
+ return false;
+ if (k % 8)
+ return false;
+ if (Btype != GGML_TYPE_F16)
+ return false;
+ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const ggml_fp16_t *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+ if (k % 4)
+ return false;
+ if (Btype != GGML_TYPE_F32)
+ return false;
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+ k, (const ggml_fp16_t *)A, lda,
+ (const float *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_Q8_0: {
+ if (Btype != GGML_TYPE_Q8_0)
+ return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
+ k, (const block_q8_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+ tinyBLAS_Q0_ARM<block_q8_0> tb{
+ k, (const block_q8_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ case GGML_TYPE_Q4_0: {
+ if (Btype != GGML_TYPE_Q8_0)
+ return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+ tinyBLAS_Q0_ARM<block_q4_0> tb{
+ k, (const block_q4_0 *)A, lda,
+ (const block_q8_0 *)B, ldb,
+ (float *)C, ldc,
+ ith, nth};
+ tb.matmul(m, n);
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ default:
+ return false;
+ }
+
+ (void)m;
+ (void)n;
+ (void)k;
+ (void)A;
+ (void)lda;
+ (void)B;
+ (void)ldb;
+ (void)C;
+ (void)ldc;
+ (void)ith;
+ (void)nth;
+ (void)Atype;
+ (void)Btype;
+ (void)Ctype;
+}
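
Following the doc comment on llamafile_sgemm, a minimal single-threaded F32 call might look like the sketch below. The sizes are illustrative, and handling a false return (falling back to a general matmul) is the caller's responsibility, as the comment states.

// Minimal single-threaded usage sketch for the F32 path of llamafile_sgemm.
#include <vector>
#include "sgemm.h"
#include "ggml.h" // GGML_TYPE_F32

void sgemm_f32_example() {
    const int64_t m = 64, n = 64, k = 128;           // k should be a multiple of the SIMD width checked above
    std::vector<float> A(m * k), B(n * k), C(m * n); // A is the transposed operand; lda/ldb are row strides = k
    const bool ok = llamafile_sgemm(m, n, k,
                                    A.data(), /*lda=*/k,
                                    B.data(), /*ldb=*/k,
                                    C.data(), /*ldc=*/m,
                                    /*ith=*/0, /*nth=*/1,
                                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
    if (!ok) {
        // fall back to a general matmul routine, as the doc comment suggests
    }
}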
diff --git a/ggml/src/llamafile/sgemm.h b/ggml/src/llamafile/sgemm.h
new file mode 100644
index 00000000..caf6dd55
--- /dev/null
+++ b/ggml/src/llamafile/sgemm.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <stdint.h>
+#include <stdbool.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+ const void *, int64_t, void *, int64_t, int, int,
+ int, int, int);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/vulkan-shaders/CMakeLists.txt
new file mode 100644
index 00000000..41551e00
--- /dev/null
+++ b/ggml/src/vulkan-shaders/CMakeLists.txt
@@ -0,0 +1,5 @@
+
+set(TARGET vulkan-shaders-gen)
+add_executable(${TARGET} vulkan-shaders-gen.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/ggml/src/vulkan-shaders/add.comp b/ggml/src/vulkan-shaders/add.comp
new file mode 100644
index 00000000..8475b011
--- /dev/null
+++ b/ggml/src/vulkan-shaders/add.comp
@@ -0,0 +1,12 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) + FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+}
diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/vulkan-shaders/argsort.comp
new file mode 100644
index 00000000..e55414b0
--- /dev/null
+++ b/ggml/src/vulkan-shaders/argsort.comp
@@ -0,0 +1,71 @@
+#version 450
+
+#include "types.comp"
+
+#define BLOCK_SIZE 1024
+#define ASC 0
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {int data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint ncols;
+ uint ncols_pad;
+ uint order;
+} p;
+
+shared int dst_row[BLOCK_SIZE];
+
+void swap(uint idx0, uint idx1) {
+ int tmp = dst_row[idx0];
+ dst_row[idx0] = dst_row[idx1];
+ dst_row[idx1] = tmp;
+}
+
+void main() {
+ // bitonic sort
+ const int col = int(gl_LocalInvocationID.x);
+ const uint row = gl_WorkGroupID.y;
+
+ if (col >= p.ncols_pad) {
+ return;
+ }
+
+ const uint row_offset = row * p.ncols;
+
+ // initialize indices
+ dst_row[col] = col;
+ barrier();
+
+ for (uint k = 2; k <= p.ncols_pad; k *= 2) {
+ for (uint j = k / 2; j > 0; j /= 2) {
+ const uint ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= p.ncols ||
+ (dst_row[ixj] < p.ncols && (p.order == ASC ?
+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
+ ) {
+ swap(col, ixj);
+ }
+ } else {
+ if (dst_row[ixj] >= p.ncols ||
+ (dst_row[col] < p.ncols && (p.order == ASC ?
+ data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
+ data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
+ ) {
+ swap(col, ixj);
+ }
+ }
+ }
+ barrier();
+ }
+ }
+
+ if (col < p.ncols) {
+ data_d[row_offset + col] = dst_row[col];
+ }
+}
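
What the bitonic-sort shader produces is an argsort of each row of ncols values. A plain CPU reference using std::stable_sort is shown below for comparison; names are illustrative, and note that the GPU sort is not stable, so rows containing ties may order equal elements differently.

// Hypothetical CPU reference for argsort.comp: indices of one row, sorted by value.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int32_t> argsort_row(const float * row, uint32_t ncols, bool ascending) {
    std::vector<int32_t> idx(ncols);
    std::iota(idx.begin(), idx.end(), 0); // initialize indices, as the shader does
    std::stable_sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) {
        return ascending ? row[a] < row[b] : row[a] > row[b];
    });
    return idx;
}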
diff --git a/ggml/src/vulkan-shaders/clamp.comp b/ggml/src/vulkan-shaders/clamp.comp
new file mode 100644
index 00000000..ca272e22
--- /dev/null
+++ b/ggml/src/vulkan-shaders/clamp.comp
@@ -0,0 +1,13 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
+}
diff --git a/ggml/src/vulkan-shaders/copy.comp b/ggml/src/vulkan-shaders/copy.comp
new file mode 100644
index 00000000..efb55876
--- /dev/null
+++ b/ggml/src/vulkan-shaders/copy.comp
@@ -0,0 +1,16 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+#else
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = data_a[src0_idx(gl_GlobalInvocationID.x)];
+#endif
+}
diff --git a/ggml/src/vulkan-shaders/dequant_f32.comp b/ggml/src/vulkan-shaders/dequant_f32.comp
new file mode 100644
index 00000000..a4d3fca5
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_f32.comp
@@ -0,0 +1,20 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.x * 16;
+
+ if (i >= p.nel) {
+ return;
+ }
+
+ [[unroll]] for (uint l = 0; l < 16; l++) {
+ data_b[i + l] = D_TYPE(data_a[i + l]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_funcs.comp b/ggml/src/vulkan-shaders/dequant_funcs.comp
new file mode 100644
index 00000000..d5b98973
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_funcs.comp
@@ -0,0 +1,68 @@
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#endif
+
+#if defined(DATA_A_F32)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
+}
+#endif
+
+#if defined(DATA_A_F16)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
+}
+#endif
+
+#if defined(DATA_A_Q4_0)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+}
+#endif
+
+#if defined(DATA_A_Q4_1)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ const float m = float(data_a[a_offset + ib].m);
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return vec2(vui & 0xF, vui >> 4) * d + m;
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+}
+#endif
+
+#if defined(DATA_A_Q5_1)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ const float m = float(data_a[a_offset + ib].m);
+ const uint uint_qh = data_a[a_offset + ib].qh;
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
+}
+#endif
+
+#if defined(DATA_A_IQ4_NL)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const float d = float(data_a[a_offset + ib].d);
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+}
+#endif
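
Each dequantize helper above maps two packed values from one block to floats. The scalar C++ sketch below restates the Q4_0 and Q8_0 formulas with simplified stand-in structs (the real blocks store an fp16 scale and are defined in ggml-common.h); it is a reference for reading the GLSL, not part of the patch.

// Scalar sketch of the Q4_0 and Q8_0 dequantize cases (simplified stand-in structs).
#include <cstdint>
#include <utility>

struct q4_0_block { float d; uint8_t qs[16]; }; // assumption: fp16 scale shown as float
struct q8_0_block { float d; int8_t  qs[32]; };

std::pair<float, float> dequantize_q4_0(const q4_0_block & b, uint32_t iqs) {
    const uint32_t vui = b.qs[iqs];
    // low and high nibble, offset by 8, scaled by d
    return { ((vui & 0xF) - 8.0f) * b.d, ((vui >> 4) - 8.0f) * b.d };
}

std::pair<float, float> dequantize_q8_0(const q8_0_block & b, uint32_t iqs) {
    // two consecutive int8 values, scaled by d
    return { b.qs[iqs] * b.d, b.qs[iqs + 1] * b.d };
}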
diff --git a/ggml/src/vulkan-shaders/dequant_head.comp b/ggml/src/vulkan-shaders/dequant_head.comp
new file mode 100644
index 00000000..8d806435
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_head.comp
@@ -0,0 +1,13 @@
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint nel;
+} p;
+
+#include "types.comp"
diff --git a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
new file mode 100644
index 00000000..34ef3da3
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_iq4_nl.comp
@@ -0,0 +1,30 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint q_idx = 8*il;
+ const uint b_idx = 1024*i + 32*ir + q_idx;
+
+ const float d = float(data_a[ib].d);
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
+ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q2_k.comp b/ggml/src/vulkan-shaders/dequant_q2_k.comp
new file mode 100644
index 00000000..157154af
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q2_k.comp
@@ -0,0 +1,34 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
+ if (i >= p.M * p.K / QUANT_K) {
+ return;
+ }
+
+ const uint tid = gl_LocalInvocationID.x;
+ const uint ip = tid / 32;
+ const uint il = tid - 32 * ip;
+ const uint is = 8 * ip + il / 16;
+
+ const uint y_idx = i * QUANT_K + 128 * ip + il;
+
+ const uint ql_idx = 32 * ip + il;
+ const uint8_t qs = data_a[i].qs[32 * ip + il];
+
+ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
+ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
+ data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
+ data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
+ data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q3_k.comp b/ggml/src/vulkan-shaders/dequant_q3_k.comp
new file mode 100644
index 00000000..c17dd0d9
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q3_k.comp
@@ -0,0 +1,42 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
+ if (i >= p.M * p.K / QUANT_K) {
+ return;
+ }
+
+ const uint r = gl_LocalInvocationID.x / 4;
+ const uint tid = r / 2;
+ const uint is0 = r % 2;
+ const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
+ const uint n = tid / 4;
+ const uint j = tid - 4*n;
+
+ const uint8_t m = uint8_t(1 << (4*n + j));
+ const uint is = 8*n + 2*j + is0;
+ const uint shift = 2*j;
+
+ const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
+ (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
+ const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
+ const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
+
+ const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
+ const uint qs_idx = 32*n;
+
+ for (uint l = l0; l < l0 + 4; ++l) {
+ data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_0.comp b/ggml/src/vulkan-shaders/dequant_q4_0.comp
new file mode 100644
index 00000000..40818532
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q4_0.comp
@@ -0,0 +1,30 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint q_idx = 8*il;
+ const uint b_idx = 1024*i + 32*ir + q_idx;
+
+ const float d = float(data_a[ib].d);
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
+ data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_1.comp b/ggml/src/vulkan-shaders/dequant_q4_1.comp
new file mode 100644
index 00000000..2f27eee6
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q4_1.comp
@@ -0,0 +1,32 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+
+ const uint q_idx = 8*il;
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
+ data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q4_k.comp b/ggml/src/vulkan-shaders/dequant_q4_k.comp
new file mode 100644
index 00000000..92acb754
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q4_k.comp
@@ -0,0 +1,56 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
+ if (i >= p.M * p.K / QUANT_K) {
+ return;
+ }
+
+ const uint tid = gl_LocalInvocationID.x;
+ const uint il = tid / 8;
+ const uint ir = tid % 8;
+ const uint is = 2 * il;
+ const uint n = 4;
+
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+
+ const uint y_idx = i * QUANT_K + 64 * il + n * ir;
+ const uint qs_idx = 32*il + n * ir;
+
+ uint8_t sc;
+ uint8_t m;
+ if (is < 4) {
+ sc = uint8_t(data_a[i].scales[is] & 63);
+ m = uint8_t(data_a[i].scales[is + 4] & 63);
+ } else {
+ sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
+ m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
+ }
+ const FLOAT_TYPE d1 = dall * sc;
+ const FLOAT_TYPE m1 = dmin * m;
+
+ if (is < 4) {
+ sc = uint8_t(data_a[i].scales[is + 1] & 63);
+ m = uint8_t(data_a[i].scales[is + 5] & 63);
+ } else {
+ sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
+ m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
+ }
+ const FLOAT_TYPE d2 = dall * sc;
+ const FLOAT_TYPE m2 = dmin * m;
+
+ [[unroll]] for (uint l = 0; l < n; ++l) {
+ data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
+ data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
+ }
+ }
+}
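
Note: the two if/else blocks above recover the 6-bit (scale, min) pairs that Q4_K packs into its 12 scale bytes; the same decode reappears in dequant_q5_k.comp below. A CPU sketch of the unpacking, mirroring the shader; `q4k_scale_min` is a hypothetical helper and the layout is assumed to match ggml's block_q4_K:

```c
#include <stdint.h>

// Sketch: recover the 6-bit scale and min for sub-block `is` (0..7) of a
// Q4_K/Q5_K block from its 12 packed scale bytes, mirroring the shader.
static void q4k_scale_min(const uint8_t scales[12], int is, uint8_t *sc, uint8_t *m) {
    if (is < 4) {
        *sc = scales[is]     & 63;
        *m  = scales[is + 4] & 63;
    } else {
        *sc = (scales[is + 4] & 0xF) | ((scales[is - 4] >> 6) << 4);
        *m  = (scales[is + 4] >>  4) | ((scales[is    ] >> 6) << 4);
    }
    // each 4-bit quant q then dequantizes as  d * sc * q - dmin * m
}
```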
diff --git a/ggml/src/vulkan-shaders/dequant_q5_0.comp b/ggml/src/vulkan-shaders/dequant_q5_0.comp
new file mode 100644
index 00000000..b20b8052
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q5_0.comp
@@ -0,0 +1,34 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+
+ const uint q_idx = 8*il;
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ const uint iqs = q_idx + l;
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
+ }
+}
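
Note: qh carries the fifth bit of each of the 32 quants: bit j extends the low nibble of qs[j] and bit j+16 extends the high nibble, which is what the ((qh >> iqs) << 4) & 0x10 and (qh >> (iqs + 12)) & 0x10 expressions extract. A CPU sketch of the same reconstruction; `dequant_q5_0_pair` is a hypothetical helper and `qh` is the assembled 32-bit high-bit word as built in the shader above:

```c
#include <stdint.h>

// Sketch: rebuild the two 5-bit quants stored at byte j (0..15) of a Q5_0 block
// and dequantize them, mirroring the shader above.
static void dequant_q5_0_pair(float d, uint32_t qh, const uint8_t qs[16], int j,
                              float *lo, float *hi) {
    const uint8_t xh0 = ((qh >> (j +  0)) << 4) & 0x10;  // bit j      -> bit 4
    const uint8_t xh1 =  (qh >> (j + 12))       & 0x10;  // bit j + 16 -> bit 4
    *lo = d * ((int)((qs[j] & 0x0F) | xh0) - 16);
    *hi = d * ((int)((qs[j] >>   4) | xh1) - 16);
}
```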
diff --git a/ggml/src/vulkan-shaders/dequant_q5_1.comp b/ggml/src/vulkan-shaders/dequant_q5_1.comp
new file mode 100644
index 00000000..dc59fe3b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q5_1.comp
@@ -0,0 +1,35 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 8*il;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint qh = data_a[ib].qh;
+
+ const uint q_idx = 8*il;
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ const uint iqs = q_idx + l;
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
+ data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q5_k.comp b/ggml/src/vulkan-shaders/dequant_q5_k.comp
new file mode 100644
index 00000000..f314a76d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q5_k.comp
@@ -0,0 +1,58 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
+ if (i >= p.M * p.K / QUANT_K) {
+ return;
+ }
+
+ const uint tid = gl_LocalInvocationID.x;
+ const uint il = tid / 16;
+ const uint ir = tid % 16;
+ const uint is = 2 * il;
+
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+
+ const uint y_idx = i * QUANT_K + 64 * il + 2 * ir;
+ const uint qs_idx = 32*il + 2 * ir;
+ const uint qh_idx = 2 * ir;
+
+ uint8_t sc;
+ uint8_t m;
+ if (is < 4) {
+ sc = uint8_t(data_a[i].scales[is] & 63);
+ m = uint8_t(data_a[i].scales[is + 4] & 63);
+ } else {
+ sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
+ m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
+ }
+ const FLOAT_TYPE d1 = dall * sc;
+ const FLOAT_TYPE m1 = dmin * m;
+
+ if (is < 4) {
+ sc = uint8_t(data_a[i].scales[is + 1] & 63);
+ m = uint8_t(data_a[i].scales[is + 5] & 63);
+ } else {
+ sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
+ m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
+ }
+ const FLOAT_TYPE d2 = dall * sc;
+ const FLOAT_TYPE m2 = dmin * m;
+
+ const uint8_t hm1 = uint8_t(1 << (2 * il ));
+ const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
+ data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
+ data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
+ data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
+ data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/dequant_q6_k.comp b/ggml/src/vulkan-shaders/dequant_q6_k.comp
new file mode 100644
index 00000000..0b913175
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q6_k.comp
@@ -0,0 +1,33 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
+ const uint i = gl_WorkGroupID.x * 256 + wgy;
+ if (i >= p.M * p.K / QUANT_K) {
+ return;
+ }
+ const uint tid = gl_LocalInvocationID.x;
+ const uint ip = tid / 32;
+ const uint il = tid - 32 * ip;
+ const uint is = 8 * ip + il / 16;
+
+ const uint y_idx = i * QUANT_K + 128 * ip + il;
+
+ const uint ql_idx = 64 * ip + il;
+ const uint8_t qh = data_a[i].qh[32 * ip + il];
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
+
+ data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
+ data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
+ data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
+ data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
+ }
+}
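
Note: each 6-bit Q6_K quant is split between ql (low 4 bits, two values per byte) and qh (upper 2 bits, four values per byte), recentred by 32, then scaled by a signed 8-bit sub-block scale and the block scale d, exactly as in the four writes above. A CPU sketch for one 128-value half of a block, with the pointers assumed to be advanced to that half; `dequant_q6_k_half` is a hypothetical helper mirroring the shader:

```c
#include <stdint.h>

// Sketch: dequantize one 128-value half of a Q6_K block, mirroring the shader.
// ql: 64 bytes of low nibbles, qh: 32 bytes of 2-bit high parts,
// sc: the 8 signed sub-block scales of this half, d: block scale.
static void dequant_q6_k_half(float d, const uint8_t ql[64], const uint8_t qh[32],
                              const int8_t sc[8], float y[128]) {
    for (int l = 0; l < 32; ++l) {
        const int is = l / 16;  // selects which pair of sub-block scales applies
        y[l +  0] = d * sc[is + 0] * ((int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32);
        y[l + 32] = d * sc[is + 2] * ((int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32);
        y[l + 64] = d * sc[is + 4] * ((int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32);
        y[l + 96] = d * sc[is + 6] * ((int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32);
    }
}
```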
diff --git a/ggml/src/vulkan-shaders/dequant_q8_0.comp b/ggml/src/vulkan-shaders/dequant_q8_0.comp
new file mode 100644
index 00000000..bd1344a8
--- /dev/null
+++ b/ggml/src/vulkan-shaders/dequant_q8_0.comp
@@ -0,0 +1,31 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint b_idx = 1024*i + 32*ir + 16*il;
+
+ const float d = float(data_a[ib].d);
+
+ const uint q_idx = 16*il;
+
+ [[unroll]] for (uint l = 0; l < 16; l += 2) {
+ data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
+ data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/diag_mask_inf.comp b/ggml/src/vulkan-shaders/diag_mask_inf.comp
new file mode 100644
index 00000000..4e68742b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/diag_mask_inf.comp
@@ -0,0 +1,34 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : enable
+
+layout (push_constant) uniform parameter
+{
+ uint ncols;
+ uint rows_per_channel;
+ uint n_past;
+} p;
+
+#include "types.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint col = gl_GlobalInvocationID.y;
+ const uint row = gl_GlobalInvocationID.x;
+
+ if (col >= p.ncols) {
+ return;
+ }
+
+ const uint i = row*p.ncols + col;
+ if (col > p.n_past + row % p.rows_per_channel) {
+ data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
+ } else {
+ data_d[i] = D_TYPE(data_a[i]);
+ }
+}
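
Note: uintBitsToFloat(0xFF800000) writes negative infinity, since that bit pattern is IEEE-754 single-precision -inf (sign bit set, exponent all ones, mantissa zero); masked positions then contribute nothing after a subsequent softmax. A small standalone check, purely illustrative:

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <string.h>

int main(void) {
    const uint32_t bits = 0xFF800000u;  // sign=1, exponent=0xFF, mantissa=0
    float f;
    memcpy(&f, &bits, sizeof f);        // type-pun via memcpy (well-defined)
    assert(isinf(f) && f < 0.0f);       // i.e. -INFINITY
    return 0;
}
```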
diff --git a/ggml/src/vulkan-shaders/div.comp b/ggml/src/vulkan-shaders/div.comp
new file mode 100644
index 00000000..8ee4bfc7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/div.comp
@@ -0,0 +1,12 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) / FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+}
diff --git a/ggml/src/vulkan-shaders/gelu.comp b/ggml/src/vulkan-shaders/gelu.comp
new file mode 100644
index 00000000..9fe807cc
--- /dev/null
+++ b/ggml/src/vulkan-shaders/gelu.comp
@@ -0,0 +1,25 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const float GELU_COEF_A = 0.044715f;
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+ const uint i = gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float xi = float(data_a[i]);
+ const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
+ data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
+}
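
Note: this is the usual tanh-based GELU approximation written to avoid a tanh call: with val = SQRT_2_OVER_PI * x * (1 + GELU_COEF_A * x * x), tanh(val) = 1 - 2/(exp(2*val) + 1), so 0.5*x*(2 - 2/(exp(2*val) + 1)) = 0.5*x*(1 + tanh(val)).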
diff --git a/ggml/src/vulkan-shaders/generic_binary_head.comp b/ggml/src/vulkan-shaders/generic_binary_head.comp
new file mode 100644
index 00000000..ab45d256
--- /dev/null
+++ b/ggml/src/vulkan-shaders/generic_binary_head.comp
@@ -0,0 +1,48 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint ne;
+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
+ uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
+ uint d_offset;
+ float param1; float param2;
+} p;
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+uint src0_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
+
+uint src1_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+
+ return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+}
+
+uint dst_idx(uint idx) {
+ const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
+ const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
+ const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
+ const uint i22_offset = i22*p.ne21*p.ne20;
+ const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
+ const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
+ return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+}
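
Note: these helpers decompose the flat element index into 4-D coordinates and then into element offsets via the nbXY strides; src1_idx additionally takes each coordinate modulo the corresponding ne1X, which implements ggml-style broadcasting of the second operand (a src1 dimension of size 1 repeats across the matching src0 dimension). As a worked example with ne00 = 4, ne01 = 3, ne02 = ne03 = 1 and idx = 10: i03 = i02 = 0, i01 = 10 / 4 = 2, i00 = 10 - 2*4 = 2, giving the offset 2*nb01 + 2*nb00.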
diff --git a/ggml/src/vulkan-shaders/generic_head.comp b/ggml/src/vulkan-shaders/generic_head.comp
new file mode 100644
index 00000000..66e46ae6
--- /dev/null
+++ b/ggml/src/vulkan-shaders/generic_head.comp
@@ -0,0 +1,9 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint KX;
+ uint KY;
+ float param1;
+ float param2;
+} p;
diff --git a/ggml/src/vulkan-shaders/generic_unary_head.comp b/ggml/src/vulkan-shaders/generic_unary_head.comp
new file mode 100644
index 00000000..de08de7c
--- /dev/null
+++ b/ggml/src/vulkan-shaders/generic_unary_head.comp
@@ -0,0 +1,35 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint ne;
+ uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
+ uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
+ uint d_offset;
+ float param1; float param2;
+} p;
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+uint src0_idx(uint idx) {
+ const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+ const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+ const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+ const uint i02_offset = i02*p.ne01*p.ne00;
+ const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
+ const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+ return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
+
+uint dst_idx(uint idx) {
+ const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
+ const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+ const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
+ const uint i12_offset = i12*p.ne11*p.ne10;
+ const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
+ const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+ return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
+}
diff --git a/ggml/src/vulkan-shaders/get_rows.comp b/ggml/src/vulkan-shaders/get_rows.comp
new file mode 100644
index 00000000..e9ff22ef
--- /dev/null
+++ b/ggml/src/vulkan-shaders/get_rows.comp
@@ -0,0 +1,26 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ const uint i00 = gl_GlobalInvocationID.x;
+ const uint i10 = gl_GlobalInvocationID.y;
+ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
+ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
+
+ if (i00 >= p.ne00) {
+ return;
+ }
+
+ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+
+ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+
+#ifndef OPTIMIZATION_ERROR_WORKAROUND
+ data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+#else
+ data_d[d_offset + i00] = data_a[a_offset + i00];
+#endif
+}
diff --git a/ggml/src/vulkan-shaders/get_rows_quant.comp b/ggml/src/vulkan-shaders/get_rows_quant.comp
new file mode 100644
index 00000000..53a9a96f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/get_rows_quant.comp
@@ -0,0 +1,31 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+#include "dequant_funcs.comp"
+
+void main() {
+ const uint i00 = (gl_GlobalInvocationID.x)*2;
+ const uint i10 = gl_GlobalInvocationID.y;
+ const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
+ const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
+
+ if (i00 >= p.ne00) {
+ return;
+ }
+
+ const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
+
+ const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
+ const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
+
+ const uint ib = a_offset + i00/QUANT_K; // block index
+ const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = i00 - i00%QUANT_K; // dst block start index
+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+
+ vec2 v = dequantize(ib, iqs, 0);
+
+ data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
+ data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
+}
diff --git a/ggml/src/vulkan-shaders/mul.comp b/ggml/src/vulkan-shaders/mul.comp
new file mode 100644
index 00000000..bbb0aa1d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul.comp
@@ -0,0 +1,12 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(data_b[src1_idx(gl_GlobalInvocationID.x)]));
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
new file mode 100644
index 00000000..825b9103
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp
@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 1) writeonly buffer D {float data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint ne;
+ uint k_num;
+} p;
+
+void main() {
+ const uint idx = gl_GlobalInvocationID.x;
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ float result = 0.0f;
+
+ [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+ result += data_a[i * p.ne + idx];
+ }
+
+ data_d[idx] = result;
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp
new file mode 100644
index 00000000..15d2a806
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec.comp
@@ -0,0 +1,50 @@
+#version 450
+
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#endif
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+ const uint tid = gl_LocalInvocationID.x;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
+
+ tmp[tid] = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint i = 0; i < p.ncols/BLOCK_SIZE; i += 2) {
+ const uint col = i*BLOCK_SIZE + 2*tid;
+ const uint ib = (row*p.ncols + col)/QUANT_K; // block index
+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = col - col%QUANT_K; // y block start index
+
+ vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
+
+ // matrix multiplication
+ tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[b_offset + iybs + iqs]) +
+ FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
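
Note: each invocation accumulates a strided slice of the dot product, two dequantized values per step; the pair returned by dequantize() maps to y elements that are either adjacent (QUANT_R == 1) or half a block apart (QUANT_R == 2), which is what y_offset encodes. The final loop is a standard shared-memory tree reduction: the number of active threads is halved each step, with barrier() between steps because tmp[] is shared, so the row result is ready in tmp[0] after log2(BLOCK_SIZE) steps.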
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
new file mode 100644
index 00000000..5920bc93
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
@@ -0,0 +1,81 @@
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_8bit_storage : require
+
+#define K_QUANTS_PER_ITERATION 2
+
+#ifdef MUL_MAT_ID
+#define EXPERT_COUNT 8
+#endif
+
+#include "types.comp"
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+#endif
+
+#include "dequant_funcs.comp"
+
+layout (push_constant) uniform parameter
+{
+ uint ncols;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint ne11;
+#else
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+} p;
+
+void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.y;
+#else
+ const uint batch_idx = gl_GlobalInvocationID.y;
+#endif
+
+#ifndef MUL_MAT_ID
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#else
+ const uint expert_id = data_ids[expert_idx];
+#endif
+
+ a_offset =
+#ifdef MUL_MAT_ID
+ expert_id * p.batch_stride_a;
+#else
+ batch_idx_a * p.batch_stride_a;
+#endif
+ b_offset =
+#ifdef MUL_MAT_ID
+ (expert_idx % p.ne11) * p.stride_b;
+#else
+ batch_idx * p.batch_stride_b;
+#endif
+ d_offset =
+#ifdef MUL_MAT_ID
+ expert_idx * p.stride_d;
+#else
+ batch_idx * p.batch_stride_d;
+#endif
+}
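
Note: get_offsets covers the two ways this shader family is driven. In the plain path, gl_GlobalInvocationID.y walks batches and broadcast2/broadcast3 collapse the src1 batch index back onto src0 when src0 has fewer channels/batches. Under MUL_MAT_ID (the indirect matmul path used for mixture-of-experts models), the same dimension walks expert slots and data_ids[expert_idx] selects which expert matrix inside A to read.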
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
new file mode 100644
index 00000000..cb3f3c0d
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp
@@ -0,0 +1,71 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#define BLOCK_SIZE 32
+#define FLOAT_TYPE float
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+
+layout (push_constant) uniform parameter
+{
+ uint ncols_x;
+ uint nrows_x;
+ uint row_stride_x;
+ uint channel_stride_x;
+ uint channel_x_divisor;
+ uint b_offset;
+ uint d_offset;
+} p;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
+
+void main() {
+ const uint tid = gl_LocalInvocationID.x;
+ const uint row_x = gl_GlobalInvocationID.y;
+ const uint channel = gl_GlobalInvocationID.z;
+ const uint channel_x = channel / p.channel_x_divisor;
+
+ const uint nrows_y = p.ncols_x;
+ const uint nrows_dst = p.nrows_x;
+ const uint row_dst = row_x;
+
+ const uint idst = channel*nrows_dst + row_dst;
+
+ tmp[tid] = 0.0f;
+
+ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
+ const uint col_x = col_x0 + tid;
+
+ if (col_x >= p.ncols_x) {
+ break;
+ }
+
+ const uint row_y = col_x;
+
+ const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
+ const uint iy = channel*nrows_y + row_y;
+
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+ tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ if (tid == 0) {
+ dst[idst] = tmp[0];
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
new file mode 100644
index 00000000..4b1871ca
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp
@@ -0,0 +1,73 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#define BLOCK_SIZE 32
+#define FLOAT_TYPE float
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+
+layout (push_constant) uniform parameter
+{
+ uint ncols_x;
+ uint nrows_x;
+ uint nchannels_x;
+ uint nchannels_y;
+ uint b_offset;
+ uint d_offset;
+} p;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
+
+void main() {
+ const uint tid = gl_LocalInvocationID.x;
+ const uint row_x = gl_GlobalInvocationID.y;
+ const uint channel = gl_GlobalInvocationID.z;
+ const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
+
+ const uint nrows_y = p.ncols_x;
+ const uint nrows_dst = p.nrows_x;
+ const uint row_dst = row_x;
+
+ tmp[tid] = FLOAT_TYPE(0.0f);
+
+ for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
+ const uint col_x = col_x0 + tid;
+
+ if (col_x >= p.ncols_x) {
+ break;
+ }
+
+ // x is transposed and permuted
+ const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
+
+ const uint row_y = col_x;
+
+ // y is not transposed but permuted
+ const uint iy = channel*nrows_y + row_y;
+
+ tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
+ }
+
+ // dst is not transposed and not permuted
+ const uint idst = channel*nrows_dst + row_dst;
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+
+ if (tid == 0) {
+ dst[idst] = tmp[0];
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
new file mode 100644
index 00000000..4cd97799
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -0,0 +1,73 @@
+#version 450
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
+
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
+
+ const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint s_offset = 8*v_im;
+ const uint y_offset = 128*v_im + l0;
+
+ tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
+
+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
+ sum2 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
+ }
+ tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
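
Note: the dall * sum1 - dmin * sum2 split follows from the Q2_K value formula: each quant contributes x = dall * (scale & 0xF) * q - dmin * (scale >> 4), so the dot product with y factors into one sum weighted by the 4-bit scales (sum1) and one weighted by the 4-bit mins (sum2), letting the two per-block constants be applied once per block instead of once per element.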
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
new file mode 100644
index 00000000..a6e430ea
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -0,0 +1,66 @@
+#version 450
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
+
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
+
+ const uint8_t m = uint8_t(1 << (4 * v_im));
+
+ const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 128*v_im + l0;
+
+ tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+ const uint s_shift = 4 * v_im;
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
+ }
+ tmp[16 * ix + tid] += d * sum;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
new file mode 100644
index 00000000..75569363
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -0,0 +1,115 @@
+#version 450
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
+
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+ const uint il = tid/step; // 0...3
+ const uint ir = tid - step*il; // 0...7 or 0...3
+ const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+
+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const uint v_in = il % 2;
+
+ const uint l0 = n * (2 * ir + v_in); // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 64*v_im + l0;
+
+ tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
+
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
+
+ const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
+ const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
+ const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
+ const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
+ const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
+ const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
+ const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
+ const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
+
+#if K_QUANTS_PER_ITERATION == 2
+ const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
+ const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
+ const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
+ const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
+ const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
+ const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
+ const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
+ const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
+ const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
+ const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
+ const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
+ const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
+ const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
+ const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
+ const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
+ const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
+
+ const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
+ const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
+ const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
+ const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
+ const FLOAT_TYPE smin = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y1_idx ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
+ );
+ tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+#else
+ const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
+ const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
+ const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
+ const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
+ const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
+ const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
+ const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
+ const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
+
+ const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
+ const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
+ const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
+ const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
+ const FLOAT_TYPE smin = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
+ );
+
+ tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+#endif
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
new file mode 100644
index 00000000..9be3645b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -0,0 +1,111 @@
+#version 450
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
+
+ const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1
+
+ const uint il = tid/4; // 0...3
+ const uint ir = tid - 4*il; // 0...7 or 0...3
+
+ const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const uint v_in = il % 2;
+
+ const uint l0 = 4*ir + 2*v_in; // 0...15
+ const uint q_offset = 32*v_im + l0;
+ const uint y_offset = 64*v_im + l0;
+
+ const uint8_t hm1 = uint8_t(1 << (2*v_im));
+ const uint8_t hm2 = uint8_t(hm1 << 4);
+
+ tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
+ const uint y1_idx = i * QUANT_K + y_offset;
+ const uint y2_idx = y1_idx + 128;
+
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
+
+ const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
+ const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
+ const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
+ const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
+ const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
+ const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
+ const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
+ const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
+
+ const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
+ const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
+ const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
+ const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
+ const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
+ const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
+ const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
+ const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
+ const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
+ const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
+ const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
+ const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
+ const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
+ const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
+ const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
+ const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
+
+ const FLOAT_TYPE sx = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y1_idx ]) * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
+ );
+ const FLOAT_TYPE sy = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
+ );
+ const FLOAT_TYPE sz = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y2_idx ]) * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
+ );
+ const FLOAT_TYPE sw = FLOAT_TYPE(
+ FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
+ + FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
+ );
+ const FLOAT_TYPE smin = FLOAT_TYPE(
+ (FLOAT_TYPE(data_b[b_offset + y1_idx]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17])) * sc2 + (FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49])) * sc3
+ + (FLOAT_TYPE(data_b[b_offset + y2_idx]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17])) * sc6 + (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7
+ );
+ tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
new file mode 100644
index 00000000..d610cf03
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -0,0 +1,79 @@
+#version 450
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+ const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
+
+ uint a_offset, b_offset, d_offset;
+ get_offsets(a_offset, b_offset, d_offset);
+
+ const uint num_blocks_per_row = p.ncols / QUANT_K;
+ const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
+
+ const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const uint v_in = tid - step*v_im; // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+ const uint l0 = v_in; // 0...15
+ const uint is = 0;
+#else
+ const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
+ const uint is = v_in / 4;
+#endif
+
+ const uint ql_offset = 64*v_im + l0;
+ const uint qh_offset = 32*v_im + l0;
+ const uint s_offset = 8*v_im + is;
+ const uint y_offset = 128*v_im + l0;
+
+ tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+ [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+ const uint y_idx = i * QUANT_K + y_offset;
+
+ const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+ FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
+ tmp[16 * ix + tid] += sum;
+#else
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 4; ++l) {
+ sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
+ + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
+ }
+ tmp[16 * ix + tid] += sum;
+#endif
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ data_d[d_offset + row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/vulkan-shaders/mul_mm.comp
new file mode 100644
index 00000000..5fe9d524
--- /dev/null
+++ b/ggml/src/vulkan-shaders/mul_mm.comp
@@ -0,0 +1,507 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#endif
+
+#ifdef MUL_MAT_ID
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#endif
+
+#include "types.comp"
+
+#ifndef LOAD_VEC_A
+#define LOAD_VEC_A 1
+#endif
+#ifndef LOAD_VEC_B
+#define LOAD_VEC_B 1
+#endif
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+#endif
+
+layout (push_constant) uniform parameter
+{
+ uint M;
+ uint N;
+ uint K;
+ uint stride_a;
+ uint stride_b;
+ uint stride_d;
+
+ uint batch_stride_a;
+ uint batch_stride_b;
+ uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+ uint nei0;
+ uint nei1;
+ uint nbi1;
+ uint ne11;
+#else
+ uint k_split;
+ uint ne02;
+ uint ne12;
+ uint broadcast2;
+ uint broadcast3;
+#endif
+} p;
+
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
+layout (constant_id = 4) const uint WM = 32;
+layout (constant_id = 5) const uint WN = 32;
+layout (constant_id = 6) const uint WMITER = 2;
+layout (constant_id = 7) const uint TM = 4;
+layout (constant_id = 8) const uint TN = 2;
+layout (constant_id = 9) const uint WARP = 32;
+
+shared FLOAT_TYPE buf_a[BM * (BK+1)];
+shared FLOAT_TYPE buf_b[BN * (BK+1)];
+
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[3072];
+#endif
+
+void main() {
+#ifdef MUL_MAT_ID
+ const uint expert_idx = gl_GlobalInvocationID.z;
+#else
+ const uint batch_idx = gl_GlobalInvocationID.z;
+
+ const uint i13 = batch_idx / p.ne12;
+ const uint i12 = batch_idx % p.ne12;
+
+ const uint i03 = i13 / p.broadcast3;
+ const uint i02 = i12 / p.broadcast2;
+
+ const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
+ const uint blocks_m = (p.M + BM - 1) / BM;
+ const uint ir = gl_WorkGroupID.x % blocks_m;
+ const uint ik = gl_WorkGroupID.x / blocks_m;
+ const uint ic = gl_WorkGroupID.y;
+
+ const uint warp_i = gl_LocalInvocationID.x / WARP;
+ const uint warp_r = warp_i % (BM / WM);
+ const uint warp_c = warp_i / (BM / WM);
+
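+    // Each warp owns a WM x WN block of the output tile, covered in WMITER x WNITER passes of WSUBM x WSUBN, with every invocation accumulating a TM x TN register tile per pass.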
+ const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+ const uint WSUBM = WM / WMITER;
+ const uint WSUBN = WN / WNITER;
+
+ const uint tiw = gl_LocalInvocationID.x % WARP;
+ const uint tiwr = tiw % (WSUBM / TM);
+ const uint tiwc = tiw / (WSUBM / TM);
+
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+
+ const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
+ const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
+
+#ifdef MUL_MAT_ID
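+    // Gather the B rows routed to this expert: every (ii0, ii1) pair whose id matches expert_idx gets a slot in row_ids; _ne1 counts them.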
+ uint _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+ if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+ row_ids[_ne1] = u16vec2(ii0, ii1);
+ _ne1++;
+ }
+ }
+ }
+
+ barrier();
+
+ // Workgroup has no work
+ if (ic * BN >= _ne1) return;
+#endif
+
+#ifdef MUL_MAT_ID
+ const uint start_k = 0;
+ const uint end_k = p.K;
+#else
+ const uint start_k = ik * p.k_split;
+ const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+ uint pos_a = (
+#ifdef MUL_MAT_ID
+ expert_idx * p.batch_stride_a +
+#else
+ batch_idx_a * p.batch_stride_a +
+#endif
+ ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+#ifdef MUL_MAT_ID
+ uint pos_b = 0;
+#else
+ uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
+#endif
+
+ float sums[WMITER * TM * WNITER * TN];
+ FLOAT_TYPE cache_a[WMITER * TM];
+ FLOAT_TYPE cache_b[WNITER * TN];
+
+ [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+ sums[i] = 0.0f;
+ }
+
+ [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
+ [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
+
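+            // Load one slice of the A tile into shared memory, dequantizing on the fly for the quantized formats.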
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+#if LOAD_VEC_A == 8
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
+ buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
+ buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
+ buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
+ buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
+#elif LOAD_VEC_A == 4
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+ buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
+#else
+ if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
+ buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+ } else {
+ buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
+ }
+#endif
+#elif defined(DATA_A_Q4_0)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q4_1)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q5_0)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q5_1)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const float m = float(data_a[ib].m);
+ const uint uint_qh = data_a[ib].qh;
+ const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q8_0)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 16;
+ const uint iqs = (idx & 0xF) * 2;
+
+ const float d = float(data_a[ib].d);
+ const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q2_K)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = idx % 128; // 0..127
+
+ const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+ const uint scalesi = iqs / 8; // 0..15
+ const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+
+ const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+ const uint scales = data_a[ib].scales[scalesi];
+ const vec2 d = vec2(data_a[ib].d);
+
+ const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q3_K)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = idx % 128; // 0..127
+
+ const uint n = iqs / 64; // 0,1
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+ const uint hmi = (iqs % 16) * 2; // 0,2,4..30
+ const uint j = (iqs % 64) / 4; // 0..3
+ const uint is = iqs / 8; // 0..15
+ const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
+ const uint qsshift = halfsplit * 2; // 0,2,4,6
+ const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
+
+ const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
+ (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
+ const float dl = float(data_a[ib].d) * float(us - 32);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
+ buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
+#elif defined(DATA_A_Q4_K)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = idx % 128; // 0..127
+
+ const uint n = iqs / 32; // 0,1,2,3
+ const uint b = (iqs % 32) / 16; // 0,1
+ const uint is = 2 * n + b; // 0..7
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+
+ const vec2 loadd = vec2(data_a[ib].d);
+
+ uint8_t sc;
+ uint8_t mbyte;
+ if (is < 4) {
+ sc = uint8_t(data_a[ib].scales[is ] & 63);
+ mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
+ } else {
+ sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+ mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
+ }
+ const float d = loadd.x * sc;
+ const float m = loadd.y * mbyte;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
+#elif defined(DATA_A_Q5_K)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = idx % 128; // 0..127
+
+ const uint n = iqs / 32; // 0,1,2,3
+ const uint b = (iqs % 32) / 16; // 0,1
+ const uint is = 2 * n + b; // 0..7
+ const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+ const uint qhi = (iqs % 16) * 2; // 0,2,4..30
+
+ const uint8_t hm = uint8_t(1 << (iqs / 16));
+
+ const vec2 loadd = vec2(data_a[ib].d);
+
+ uint8_t sc;
+ uint8_t mbyte;
+ if (is < 4) {
+ sc = uint8_t(data_a[ib].scales[is ] & 63);
+ mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
+ } else {
+ sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
+ mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
+ }
+ const float d = loadd.x * sc;
+ const float m = loadd.y * mbyte;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
+#elif defined(DATA_A_Q6_K)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+
+ const uint ib = idx / 128; // 2 values per idx
+ const uint iqs = idx % 128; // 0..127
+
+ const uint n = iqs / 64; // 0,1
+ const uint b = (iqs % 64) / 32; // 0,1
+ const uint is_b = (iqs % 16) / 8; // 0,1
+ const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+ const uint is = 8 * n + qhshift + is_b; // 0..15
+ const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
+ const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+
+ const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
+ buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
+#elif defined(DATA_A_IQ4_NL)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+
+ const uint ib = idx / 16;
+ const uint iqs = idx & 0xF;
+
+ const float d = float(data_a[ib].d);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
+
+ buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+#endif
+ }
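+        // Load the matching slice of the B tile (B is plain f32/f16 here, possibly vectorized).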
+ [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
+#if LOAD_VEC_B == 8
+#ifdef MUL_MAT_ID
+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
+#else
+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
+#endif
+ const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
+ buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
+ buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
+ buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
+ buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
+#elif LOAD_VEC_B == 4
+#ifdef MUL_MAT_ID
+ const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
+#else
+ const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
+#endif
+ const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+ buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
+#elif !MUL_MAT_ID
+ if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+ } else {
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+ }
+#else
+ const uint row_i = ic * BN + loadc_b + l;
+ if (row_i < _ne1) {
+ const u16vec2 row_idx = row_ids[row_i];
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+ } else {
+ buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+ }
+#endif
+ }
+
+ barrier();
+
+ pos_a += BK / LOAD_VEC_A;
+ pos_b += BK / LOAD_VEC_B;
+
+ for (uint i = 0; i < BK; i++) {
+ // Load from shared into cache
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint j = 0; j < TM; j++) {
+ cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+ }
+ }
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint j = 0; j < TN; j++) {
+ cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+ }
+ }
+
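+            // Multiply-accumulate: outer product of the cached A and B fragments into the per-invocation TM x TN register tiles.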
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+ }
+ }
+ }
+ }
+ }
+
+ barrier();
+ }
+
+ const uint dr = ir * BM + warp_r * WM;
+ const uint dc = ic * BN + warp_c * WN;
+
+#ifndef MUL_MAT_ID
+ const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+
+ const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
+ const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+#ifdef MUL_MAT_ID
+ const uint row_i = dc_warp + cc;
+ if (row_i >= _ne1) break;
+
+ const u16vec2 row_idx = row_ids[row_i];
+#endif
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+#ifdef MUL_MAT_ID
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+#else
+ if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ }
+#endif
+ }
+ }
+ }
+ }
+}
diff --git a/ggml/src/vulkan-shaders/norm.comp b/ggml/src/vulkan-shaders/norm.comp
new file mode 100644
index 00000000..803dbdcb
--- /dev/null
+++ b/ggml/src/vulkan-shaders/norm.comp
@@ -0,0 +1,44 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared vec2 sum[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ sum[tid] = vec2(0.0f, 0.0f);
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const float xi = float(data_a[row*p.KX + col]);
+ sum[tid].x += xi;
+ sum[tid].y += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
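+    // Mean and variance from the reduced sum / sum of squares; p.param1 supplies the eps for numerical stability.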
+ const float mean = sum[0].x / p.KX;
+ const float var = sum[0].y / p.KX - mean * mean;
+ const float inv_std = inversesqrt(var + p.param1);
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/relu.comp b/ggml/src/vulkan-shaders/relu.comp
new file mode 100644
index 00000000..7e5baa5b
--- /dev/null
+++ b/ggml/src/vulkan-shaders/relu.comp
@@ -0,0 +1,21 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ data_d[i] = max(float(data_a[i]), 0);
+}
diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/vulkan-shaders/rms_norm.comp
new file mode 100644
index 00000000..cfd08d34
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rms_norm.comp
@@ -0,0 +1,42 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.x;
+ const uint tid = gl_LocalInvocationID.x;
+
+ sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+ sum[tid] += xi * xi;
+ }
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
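+    // RMS norm: scale each element by 1 / sqrt(mean(x^2) + eps), with p.param1 as the eps term.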
+ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+ const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
+
+ [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+ data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/rope_head.comp b/ggml/src/vulkan-shaders/rope_head.comp
new file mode 100644
index 00000000..ea895422
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rope_head.comp
@@ -0,0 +1,44 @@
+#include "types.comp"
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {int data_pos[];};
+layout (binding = 2) readonly buffer Z {float data_ff[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter {
+ uint ncols;
+ uint n_dims;
+ float freq_scale;
+ uint p_delta_rows;
+ float freq_base;
+ float ext_factor;
+ float attn_factor;
+ float corr_dims[2];
+ float theta_scale;
+ uint has_ff;
+} p;
+
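+// YaRN ramp: 1 when i0/2 <= low, 0 when i0/2 >= high, linearly interpolated in between.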
+float rope_yarn_ramp(const float low, const float high, const uint i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
+ float mscale = p.attn_factor;
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = p.freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (p.ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
+ }
+ cos_theta = cos(theta) * mscale;
+ sin_theta = sin(theta) * mscale;
+}
diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/vulkan-shaders/rope_neox.comp
new file mode 100644
index 00000000..83b46b69
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rope_neox.comp
@@ -0,0 +1,37 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+ const uint col = gl_GlobalInvocationID.y * 2;
+ const uint row = gl_GlobalInvocationID.x;
+
+ if (col >= p.ncols) {
+ return;
+ }
+
+ if (col >= p.n_dims) {
+ const uint i = row*p.ncols + col;
+
+ data_d[i + 0] = data_a[i + 0];
+ data_d[i + 1] = data_a[i + 1];
+
+ return;
+ }
+
+ const uint i = row*p.ncols + col/2;
+ const uint i2 = row/p.p_delta_rows;
+
+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+
+ const float x0 = float(data_a[i + 0]);
+ const float x1 = float(data_a[i + p.n_dims/2]);
+
+ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/vulkan-shaders/rope_norm.comp
new file mode 100644
index 00000000..e416ad93
--- /dev/null
+++ b/ggml/src/vulkan-shaders/rope_norm.comp
@@ -0,0 +1,37 @@
+#version 450
+
+#include "rope_head.comp"
+
+void main() {
+ const uint col = gl_GlobalInvocationID.y * 2;
+ const uint row = gl_GlobalInvocationID.x;
+
+ if (col >= p.ncols) {
+ return;
+ }
+
+ if (col >= p.n_dims) {
+ const uint i = row*p.ncols + col;
+
+ data_d[i + 0] = data_a[i + 0];
+ data_d[i + 1] = data_a[i + 1];
+
+ return;
+ }
+
+ const uint i = row*p.ncols + col;
+ const uint i2 = row/p.p_delta_rows;
+
+ const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
+
+ const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
+
+ float cos_theta, sin_theta;
+ rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
+
+ const float x0 = float(data_a[i + 0]);
+ const float x1 = float(data_a[i + 1]);
+
+ data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
+ data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
+}
diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp
new file mode 100644
index 00000000..510cb723
--- /dev/null
+++ b/ggml/src/vulkan-shaders/scale.comp
@@ -0,0 +1,12 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]) * FLOAT_TYPE(p.param1));
+}
diff --git a/ggml/src/vulkan-shaders/silu.comp b/ggml/src/vulkan-shaders/silu.comp
new file mode 100644
index 00000000..15920f06
--- /dev/null
+++ b/ggml/src/vulkan-shaders/silu.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float xi = float(data_a[i]);
+ data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
+}
diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp
new file mode 100644
index 00000000..1b8419c7
--- /dev/null
+++ b/ggml/src/vulkan-shaders/soft_max.comp
@@ -0,0 +1,106 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+ uint KX;
+ uint KY;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint n_head_log2;
+} p;
+
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+shared FLOAT_TYPE vals[BLOCK_SIZE];
+
+void main() {
+ const uint tid = gl_LocalInvocationID.x;
+ const uint rowx = gl_WorkGroupID.x;
+ const uint rowy = rowx % p.KY;
+
+ float slope = 1.0f;
+
+ // ALiBi
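+    // slope = m0^(h+1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) for the rest.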
+ if (p.max_bias > 0.0f) {
+ const uint h = rowx/p.KY; // head index
+
+ const float base = h < p.n_head_log2 ? p.m0 : p.m1;
+ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // Find max
+    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); // bit pattern of -infinity
+
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
+ max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
+ }
+ vals[tid] = max_val;
+
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ vals[tid] = max(vals[tid], vals[tid + s]);
+ }
+ barrier();
+ }
+
+ max_val = vals[0];
+ barrier();
+
+ // Sum up values
+ vals[tid] = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
+ const uint i = rowx * p.KX + col;
+ const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+ vals[tid] += val;
+ data_d[i] = D_TYPE(val);
+ }
+
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ vals[tid] += vals[tid + s];
+ }
+ barrier();
+ }
+
+ const D_TYPE divisor = D_TYPE(vals[0]);
+
+ [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+ const uint col = col0 + tid;
+
+ if (col >= p.KX) {
+ break;
+ }
+
+ data_d[rowx*p.KX + col] /= divisor;
+ }
+}
diff --git a/ggml/src/vulkan-shaders/square.comp b/ggml/src/vulkan-shaders/square.comp
new file mode 100644
index 00000000..8dd19333
--- /dev/null
+++ b/ggml/src/vulkan-shaders/square.comp
@@ -0,0 +1,13 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ if (gl_GlobalInvocationID.x >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(gl_GlobalInvocationID.x)]);
+ data_d[p.d_offset + dst_idx(gl_GlobalInvocationID.x)] = D_TYPE(val * val);
+}
diff --git a/ggml/src/vulkan-shaders/sum_rows.comp b/ggml/src/vulkan-shaders/sum_rows.comp
new file mode 100644
index 00000000..ce2f1e2f
--- /dev/null
+++ b/ggml/src/vulkan-shaders/sum_rows.comp
@@ -0,0 +1,37 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmp[BLOCK_SIZE];
+
+void main() {
+ const uint row = gl_WorkGroupID.x;
+ const uint col = gl_LocalInvocationID.x;
+
+ tmp[col] = FLOAT_TYPE(0.0f);
+
+ for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
+ tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
+ }
+
+ barrier();
+ [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
+ if (col < s) {
+ tmp[col] += tmp[col + s];
+ }
+ barrier();
+ }
+
+ if (col == 0) {
+ data_d[row] = D_TYPE(tmp[0]);
+ }
+}
diff --git a/ggml/src/vulkan-shaders/types.comp b/ggml/src/vulkan-shaders/types.comp
new file mode 100644
index 00000000..d24c172c
--- /dev/null
+++ b/ggml/src/vulkan-shaders/types.comp
@@ -0,0 +1,200 @@
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#endif
+
+#if defined(DATA_A_F32)
+#define QUANT_K 1
+#define QUANT_R 1
+
+#ifndef LOAD_VEC_A
+#define A_TYPE float
+#elif LOAD_VEC_A == 4
+#define A_TYPE vec4
+#elif LOAD_VEC_A == 8
+#define A_TYPE mat2x4
+#endif
+#endif
+
+#if defined(DATA_A_F16)
+#define QUANT_K 1
+#define QUANT_R 1
+
+#ifndef LOAD_VEC_A
+#define A_TYPE float16_t
+#elif LOAD_VEC_A == 4
+#define A_TYPE f16vec4
+#elif LOAD_VEC_A == 8
+#define A_TYPE f16mat2x4
+#endif
+#endif
+
+#if defined(DATA_A_Q4_0)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 32
+#define QUANT_R 2
+
+struct block_q4_0
+{
+ float16_t d;
+ uint8_t qs[16];
+};
+
+#define A_TYPE block_q4_0
+#endif
+
+#if defined(DATA_A_Q4_1)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 32
+#define QUANT_R 2
+
+struct block_q4_1
+{
+ float16_t d;
+ float16_t m;
+ uint8_t qs[16];
+};
+
+#define A_TYPE block_q4_1
+#endif
+
+#if defined(DATA_A_Q5_0)
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#define QUANT_K 32
+#define QUANT_R 2
+
+struct block_q5_0
+{
+ float16_t d;
+ uint16_t qh[2];
+ uint8_t qs[16];
+};
+
+#define A_TYPE block_q5_0
+#endif
+
+#if defined(DATA_A_Q5_1)
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#define QUANT_K 32
+#define QUANT_R 2
+
+struct block_q5_1
+{
+ float16_t d;
+ float16_t m;
+ uint qh;
+ uint8_t qs[16];
+};
+
+#define A_TYPE block_q5_1
+#endif
+
+#if defined(DATA_A_Q8_0)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 32
+#define QUANT_R 1
+
+struct block_q8_0
+{
+ float16_t d;
+ int8_t qs[32];
+};
+
+#define A_TYPE block_q8_0
+#endif
+
+// K-quants
+#if defined(DATA_A_Q2_K)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 256
+
+struct block_q2_K
+{
+ uint8_t scales[QUANT_K/16];
+ uint8_t qs[QUANT_K/4];
+ f16vec2 d;
+};
+
+#define A_TYPE block_q2_K
+#endif
+
+#if defined(DATA_A_Q3_K)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 256
+
+struct block_q3_K
+{
+ uint8_t hmask[QUANT_K/8];
+ uint8_t qs[QUANT_K/4];
+ uint8_t scales[12];
+ float16_t d;
+};
+
+#define A_TYPE block_q3_K
+#endif
+
+#if defined(DATA_A_Q4_K)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 256
+
+struct block_q4_K
+{
+ f16vec2 d;
+ uint8_t scales[3*QUANT_K/64];
+ uint8_t qs[QUANT_K/2];
+};
+
+#define A_TYPE block_q4_K
+#endif
+
+#if defined(DATA_A_Q5_K)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 256
+
+struct block_q5_K
+{
+ f16vec2 d;
+ uint8_t scales[12];
+ uint8_t qh[QUANT_K/8];
+ uint8_t qs[QUANT_K/2];
+};
+
+#define A_TYPE block_q5_K
+#endif
+
+#if defined(DATA_A_Q6_K)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 256
+
+struct block_q6_K
+{
+ uint8_t ql[QUANT_K/2];
+ uint8_t qh[QUANT_K/4];
+ int8_t scales[QUANT_K/16];
+ float16_t d;
+};
+
+#define A_TYPE block_q6_K
+#endif
+
+// IQuants
+
+#if defined(DATA_A_IQ4_NL)
+#extension GL_EXT_shader_16bit_storage : require
+#define QUANT_K 32
+#define QUANT_R 2
+
+struct block_iq4_nl
+{
+ float16_t d;
+ uint8_t qs[QUANT_K/2];
+};
+
+#define A_TYPE block_iq4_nl
+
+const int8_t kvalues_iq4nl[16] = {
+ int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
+ int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
+};
+#endif
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
new file mode 100644
index 00000000..c5be3754
--- /dev/null
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -0,0 +1,525 @@
+
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <stdexcept>
+#include <array>
+#include <vector>
+#include <map>
+#include <thread>
+#include <mutex>
+#include <future>
+#include <queue>
+#include <condition_variable>
+#include <cstdio>
+#include <cstring>
+#include <cstdlib>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#ifdef _WIN32
+ #include <windows.h>
+ #include <direct.h> // For _mkdir on Windows
+#else
+ #include <unistd.h>
+ #include <sys/wait.h>
+ #include <fcntl.h>
+#endif
+
+#define ASYNCIO_CONCURRENCY 64
+
+std::mutex lock;
+std::vector<std::pair<std::string, std::string>> shader_fnames;
+
+std::string GLSLC = "glslc";
+std::string input_dir = "vulkan-shaders";
+std::string output_dir = "/tmp";
+std::string target_hpp = "ggml-vulkan-shaders.hpp";
+std::string target_cpp = "ggml-vulkan-shaders.cpp";
+bool no_clean = false;
+
+const std::vector<std::string> type_names = {
+ "f32",
+ "f16",
+ "q4_0",
+ "q4_1",
+ "q5_0",
+ "q5_1",
+ "q8_0",
+ "q2_k",
+ "q3_k",
+ "q4_k",
+ "q5_k",
+ "q6_k",
+ "iq4_nl"
+};
+
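+// Run `command` through the platform process API (CreateProcess on Windows, fork/exec of /bin/sh elsewhere) and capture its stdout and stderr.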
+void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
+#ifdef _WIN32
+ HANDLE stdout_read, stdout_write;
+ HANDLE stderr_read, stderr_write;
+ SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
+
+ if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
+ !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
+ throw std::runtime_error("Failed to create stdout pipe");
+ }
+
+ if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
+ !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
+ throw std::runtime_error("Failed to create stderr pipe");
+ }
+
+ PROCESS_INFORMATION pi;
+ STARTUPINFOA si = { sizeof(STARTUPINFOA) };
+ si.dwFlags = STARTF_USESTDHANDLES;
+ si.hStdOutput = stdout_write;
+ si.hStdError = stderr_write;
+
+ std::vector<char> cmd(command.begin(), command.end());
+ cmd.push_back('\0');
+
+ if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
+ throw std::runtime_error("Failed to create process");
+ }
+
+ CloseHandle(stdout_write);
+ CloseHandle(stderr_write);
+
+ std::array<char, 128> buffer;
+ DWORD bytes_read;
+
+ while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
+ stdout_str.append(buffer.data(), bytes_read);
+ }
+
+ while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
+ stderr_str.append(buffer.data(), bytes_read);
+ }
+
+ CloseHandle(stdout_read);
+ CloseHandle(stderr_read);
+ WaitForSingleObject(pi.hProcess, INFINITE);
+ CloseHandle(pi.hProcess);
+ CloseHandle(pi.hThread);
+#else
+    int stdout_pipe[2];
+ int stderr_pipe[2];
+
+ if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
+ throw std::runtime_error("Failed to create pipes");
+ }
+
+ pid_t pid = fork();
+ if (pid < 0) {
+ throw std::runtime_error("Failed to fork process");
+ }
+
+ if (pid == 0) {
+ close(stdout_pipe[0]);
+ close(stderr_pipe[0]);
+ dup2(stdout_pipe[1], STDOUT_FILENO);
+ dup2(stderr_pipe[1], STDERR_FILENO);
+ close(stdout_pipe[1]);
+ close(stderr_pipe[1]);
+ execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
+ _exit(EXIT_FAILURE);
+ } else {
+ close(stdout_pipe[1]);
+ close(stderr_pipe[1]);
+
+ std::array<char, 128> buffer;
+ ssize_t bytes_read;
+
+ while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
+ stdout_str.append(buffer.data(), bytes_read);
+ }
+
+ while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
+ stderr_str.append(buffer.data(), bytes_read);
+ }
+
+ close(stdout_pipe[0]);
+ close(stderr_pipe[0]);
+ waitpid(pid, nullptr, 0);
+ }
+#endif
+}
+
+bool directory_exists(const std::string& path) {
+ struct stat info;
+ if (stat(path.c_str(), &info) != 0) {
+ return false; // Path doesn't exist or can't be accessed
+ }
+ return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
+}
+
+bool create_directory(const std::string& path) {
+#ifdef _WIN32
+ return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
+#else
+    return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 gives rwxr-xr-x directory permissions
+#endif
+}
+
+std::string to_uppercase(const std::string& input) {
+ std::string result = input;
+ for (char& c : result) {
+ c = std::toupper(c);
+ }
+ return result;
+}
+
+bool string_ends_with(const std::string& str, const std::string& suffix) {
+ if (suffix.size() > str.size()) {
+ return false;
+ }
+ return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
+}
+
+#ifdef _WIN32
+ static const char path_separator = '\\';
+#else
+ static const char path_separator = '/';
+#endif
+
+std::string join_paths(const std::string& path1, const std::string& path2) {
+ return path1 + path_separator + path2;
+}
+
+std::string basename(const std::string &path) {
+ return path.substr(path.find_last_of("/\\") + 1);
+}
+
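+// Compile one GLSL source with glslc into <output_dir>/<name>.spv, passing the given -D defines; fp16 == false appends a _fp32 suffix to the shader name.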
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
+ std::string name = _name + (fp16 ? "" : "_fp32");
+ std::string out_fname = join_paths(output_dir, name + ".spv");
+ std::string in_path = join_paths(input_dir, in_fname);
+
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
+ for (const auto& define : defines) {
+ cmd.push_back("-D" + define.first + "=" + define.second);
+ }
+
+ std::string command;
+ for (const auto& part : cmd) {
+ command += part + " ";
+ }
+
+ std::string stdout_str, stderr_str;
+ try {
+ // std::cout << "Executing command: ";
+ // for (const auto& part : cmd) {
+ // std::cout << part << " ";
+ // }
+ // std::cout << std::endl;
+
+ execute_command(command, stdout_str, stderr_str);
+ if (!stderr_str.empty()) {
+ std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
+ return;
+ }
+
+ std::lock_guard<std::mutex> guard(lock);
+ shader_fnames.push_back(std::make_pair(name, out_fname));
+ } catch (const std::exception& e) {
+ std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
+ }
+}
+
+std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
+ std::map<std::string, std::string> result = a;
+ result.insert(b.begin(), b.end());
+ return result;
+}
+
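+// Queue asynchronous glslc jobs for the mul_mm variants: plain and MUL_MAT_ID, aligned and unaligned loads, f16 and f32 B types, plus one pair per quantization format.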
+void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id) {
+ std::string load_vec = fp16 ? "8" : "4";
+ std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
+ std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
+
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
+ std::string shader_name = "matmul";
+
+ if (matmul_id) {
+ base_dict["MUL_MAT_ID"] = "1";
+ shader_name = "matmul_id";
+ }
+
+ if (fp16) {
+ base_dict["FLOAT16"] = "1";
+ }
+
+ // Shaders with f16 B_TYPE
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
+ }));
+
+ for (const auto& tname : type_names) {
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
+ }));
+ }
+}
+
+void process_shaders(std::vector<std::future<void>>& tasks) {
+ std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+
+ for (const auto& fp16 : {false, true}) {
+ matmul_shaders(tasks, fp16, false);
+ matmul_shaders(tasks, fp16, true);
+ }
+
+ for (const auto& tname : type_names) {
+ // mul mat vec
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+ std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+
+ // Dequant shaders
+ if (tname != "f16") {
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
+ }));
+ }
+
+ if (!string_ends_with(tname, "_k")) {
+ shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+
+ if (tname == "f16") {
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ }));
+ } else {
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
+ }));
+ }
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
+ }));
+ }
+ }
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ // Norms
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ }));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [=] {
+ string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ }));
+}
+
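+// Embed each compiled .spv blob as an unsigned char array in the generated header/source pair, then remove the temporary .spv unless --no-clean was given.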
+void write_output_files() {
+ FILE* hdr = fopen(target_hpp.c_str(), "w");
+ FILE* src = fopen(target_cpp.c_str(), "w");
+
+ fprintf(hdr, "#include <cstdint>\n\n");
+ fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
+
+ for (const auto& pair : shader_fnames) {
+ const std::string& name = pair.first;
+ const std::string& path = pair.second;
+ FILE* spv = fopen(path.c_str(), "rb");
+ if (!spv) {
+ std::cerr << "Error opening SPIR-V file: " << path << "\n";
+ continue;
+ }
+
+ fseek(spv, 0, SEEK_END);
+ size_t size = ftell(spv);
+ fseek(spv, 0, SEEK_SET);
+
+ std::vector<unsigned char> data(size);
+ size_t read_size = fread(data.data(), 1, size, spv);
+ fclose(spv);
+ if (read_size != size) {
+ std::cerr << "Error reading SPIR-V file: " << path << "\n";
+ continue;
+ }
+
+ fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
+ fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
+
+ fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
+ for (size_t i = 0; i < size; ++i) {
+ fprintf(src, "0x%02x,", data[i]);
+ if ((i + 1) % 12 == 0) fprintf(src, "\n");
+ }
+ fprintf(src, "\n};\n\n");
+
+ if (!no_clean) {
+ std::remove(path.c_str());
+ }
+ }
+
+ fclose(hdr);
+ fclose(src);
+}
+
+int main(int argc, char** argv) {
+ std::map<std::string, std::string> args;
+ for (int i = 1; i < argc; i += 2) {
+ if (i + 1 < argc) {
+ args[argv[i]] = argv[i + 1];
+ }
+ }
+
+ if (args.find("--glslc") != args.end()) {
+ GLSLC = args["--glslc"]; // Path to glslc
+ }
+ if (args.find("--input-dir") != args.end()) {
+ input_dir = args["--input-dir"]; // Directory containing shader sources
+ }
+ if (args.find("--output-dir") != args.end()) {
+ output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
+ }
+ if (args.find("--target-hpp") != args.end()) {
+ target_hpp = args["--target-hpp"]; // Path to generated header file
+ }
+ if (args.find("--target-cpp") != args.end()) {
+ target_cpp = args["--target-cpp"]; // Path to generated cpp file
+ }
+ if (args.find("--no-clean") != args.end()) {
+ no_clean = true; // Keep temporary SPIR-V files in output-dir after build
+ }
+
+ if (!directory_exists(input_dir)) {
+ std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ if (!directory_exists(output_dir)) {
+ if (!create_directory(output_dir)) {
+ std::cerr << "Error creating output directory: " << output_dir << "\n";
+ return EXIT_FAILURE;
+ }
+ }
+
+ std::vector<std::future<void>> tasks;
+ process_shaders(tasks);
+
+ for (auto& task : tasks) {
+ task.get();
+ }
+
+ write_output_files();
+
+ return EXIT_SUCCESS;
+}