summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt12
-rw-r--r--Makefile18
-rw-r--r--Package.swift3
-rw-r--r--build.zig21
-rw-r--r--examples/quantize/quantize.cpp9
-rw-r--r--ggml-quants.c (renamed from k_quants.c)2248
-rw-r--r--ggml-quants.h (renamed from k_quants.h)103
-rw-r--r--ggml.c2301
-rw-r--r--ggml.h7
-rw-r--r--llama.cpp34
-rw-r--r--llama.h1
11 files changed, 2372 insertions, 2385 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d9fc8623..3659279e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -94,7 +94,6 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
option(LLAMA_MPI "llama: use MPI" OFF)
-option(LLAMA_K_QUANTS "llama: use k-quants" ON)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -278,13 +277,8 @@ if (LLAMA_BLAS)
endif()
endif()
-if (LLAMA_K_QUANTS)
- set(GGML_HEADERS_EXTRA k_quants.h)
- set(GGML_SOURCES_EXTRA k_quants.c)
- add_compile_definitions(GGML_USE_K_QUANTS)
- if (LLAMA_QKK_64)
- add_compile_definitions(GGML_QKK_64)
- endif()
+if (LLAMA_QKK_64)
+ add_compile_definitions(GGML_QKK_64)
endif()
if (LLAMA_CUBLAS)
@@ -673,6 +667,8 @@ add_library(ggml OBJECT
ggml-alloc.h
ggml-backend.c
ggml-backend.h
+ ggml-quants.c
+ ggml-quants.h
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
diff --git a/Makefile b/Makefile
index 68069f9f..2cecc221 100644
--- a/Makefile
+++ b/Makefile
@@ -342,13 +342,9 @@ else
MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
endif
-ifndef LLAMA_NO_K_QUANTS
- MK_CPPFLAGS += -DGGML_USE_K_QUANTS
- OBJS += k_quants.o
ifdef LLAMA_QKK_64
MK_CPPFLAGS += -DGGML_QKK_64
endif
-endif
ifndef LLAMA_NO_ACCELERATE
# Mac OS - include Accelerate framework.
@@ -365,7 +361,7 @@ ifdef LLAMA_MPI
MK_CPPFLAGS += -DGGML_USE_MPI
MK_CFLAGS += -Wno-cast-qual
MK_CXXFLAGS += -Wno-cast-qual
- OBJS += ggml-mpi.o
+ OBJS += ggml-mpi.o
endif # LLAMA_MPI
ifdef LLAMA_OPENBLAS
@@ -382,7 +378,7 @@ endif # LLAMA_BLIS
ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
- OBJS += ggml-cuda.o
+ OBJS += ggml-cuda.o
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
@@ -497,11 +493,6 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
$(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_MPI
-ifndef LLAMA_NO_K_QUANTS
-k_quants.o: k_quants.c k_quants.h
- $(CC) $(CFLAGS) -c $< -o $@
-endif # LLAMA_NO_K_QUANTS
-
# combine build flags with cmdline overrides
override CFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
@@ -542,7 +533,10 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
$(CC) $(CFLAGS) -c $< -o $@
-OBJS += ggml-alloc.o ggml-backend.o
+ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
+ $(CC) $(CFLAGS) -c $< -o $@
+
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@
diff --git a/Package.swift b/Package.swift
index 4ab055b1..5b3bd72c 100644
--- a/Package.swift
+++ b/Package.swift
@@ -42,13 +42,12 @@ let package = Package(
"llama.cpp",
"ggml-alloc.c",
"ggml-backend.c",
- "k_quants.c",
+ "ggml-quants.c",
] + additionalSources,
resources: resources,
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
- .define("GGML_USE_K_QUANTS"),
.define("GGML_USE_ACCELERATE")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
diff --git a/build.zig b/build.zig
index dcfa3dd6..9b58b74c 100644
--- a/build.zig
+++ b/build.zig
@@ -116,15 +116,10 @@ pub fn build(b: *std.build.Builder) !void {
var make = try Maker.init(b);
make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
- if (b.option(bool, "k-quants", "Enable K-quants, (default: true)") orelse true) {
- try make.addFlag("-DGGML_USE_K_QUANTS");
- const k_quants = make.obj("k_quants", "k_quants.c");
- try make.objs.append(k_quants);
- }
-
const ggml = make.obj("ggml", "ggml.c");
const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
+ const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
const llama = make.obj("llama", "llama.cpp");
const common = make.obj("common", "common/common.cpp");
const console = make.obj("console", "common/console.cpp");
@@ -133,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
const train = make.obj("train", "common/train.cpp");
const clip = make.obj("clip", "examples/llava/clip.cpp");
- _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser });
- _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
- _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
- _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
- _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
- _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
+ _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, console, grammar_parser });
+ _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+ _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+ _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+ _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });
+ _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });
- const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser, clip });
+ const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, grammar_parser, clip });
if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32");
}
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index c7dd0d89..be0b2fe1 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -18,7 +18,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
-#ifdef GGML_USE_K_QUANTS
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
{ "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
@@ -31,7 +30,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
{ "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
{ "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, -0.0008 ppl @ LLaMA-v1-7B", },
-#endif
{ "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", },
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
@@ -70,13 +68,14 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
}
// usage:
-// ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
+// ./quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
//
[[noreturn]]
static void usage(const char * executable) {
- printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+ printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+ printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
printf("\nAllowed quantization types:\n");
for (auto & it : QUANT_OPTIONS) {
if (it.name != "COPY") {
@@ -103,6 +102,8 @@ int main(int argc, char ** argv) {
params.quantize_output_tensor = false;
} else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
params.allow_requantize = true;
+ } else if (strcmp(argv[arg_idx], "--pure") == 0) {
+ params.pure = true;
} else {
usage(argv[0]);
}
diff --git a/k_quants.c b/ggml-quants.c
index 801941fb..fd4ee1be 100644
--- a/k_quants.c
+++ b/ggml-quants.c
@@ -1,9 +1,10 @@
-#include "k_quants.h"
+#include "ggml-quants.h"
#include "ggml.h"
#include <math.h>
#include <string.h>
#include <assert.h>
+#include <float.h>
#ifdef __ARM_NEON
@@ -65,6 +66,1024 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+ // Get absolute values of x vectors
+ const __m128i ax = _mm_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m128i sy = _mm_sign_epi8(y, x);
+ // Perform multiplication and create 16-bit values
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
+ const __m128i ones = _mm_set1_epi16(1);
+ return _mm_madd_epi16(ones, dot);
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+ __m128 res = _mm256_extractf128_ps(x, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+ return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m256i shuf_mask = _mm256_set_epi64x(
+ 0x0303030303030303, 0x0202020202020202,
+ 0x0101010101010101, 0x0000000000000000);
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytes = _mm256_or_si256(bytes, bit_mask);
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
+ return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+ const __m256i ones = _mm256_set1_epi16(1);
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+ return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if __AVXVNNI__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+#else
+ const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+ __m256i high = _mm256_andnot_si256( lowByte, bytes );
+ __m256i low = _mm256_and_si256( lowByte, bytes );
+ high = _mm256_srli_epi16( high, 4 );
+ bytes = _mm256_or_si256( low, high );
+
+ // Compress uint16_t lanes into bytes
+ __m128i r0 = _mm256_castsi256_si128( bytes );
+ __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+ return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytesl = _mm_or_si128(bytesl, bit_mask);
+ bytesh = _mm_or_si128(bytesh, bit_mask);
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+ return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+ // Load 16 bytes from memory
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ tmpl = _mm_and_si128(lowMask, tmpl);
+ tmph = _mm_and_si128(lowMask, tmph);
+ return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+ const __m128i ones = _mm_set1_epi16(1);
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+ return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ const __m128i axl = _mm256_castsi256_si128(ax);
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
+ const __m128i syl = _mm256_castsi256_si128(sy);
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ const __m128i xl = _mm256_castsi256_si128(x);
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
+ const __m128i yl = _mm256_castsi256_si128(y);
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
+ // Get absolute values of x vectors
+ const __m128i axl = _mm_sign_epi8(xl, xl);
+ const __m128i axh = _mm_sign_epi8(xh, xh);
+ // Sign the values of the y vectors
+ const __m128i syl = _mm_sign_epi8(yl, xl);
+ const __m128i syh = _mm_sign_epi8(yh, xh);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes1 = _mm_or_si128( low, high );
+ high = _mm_andnot_si128( lowByte, bytes2 );
+ low = _mm_and_si128( lowByte, bytes2 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes2 = _mm_or_si128( low, high );
+
+ return _mm_packus_epi16( bytes1, bytes2);
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+ __m128 res_0 =_mm_hadd_ps(a, b);
+ __m128 res_1 =_mm_hadd_ps(c, d);
+ __m128 res =_mm_hadd_ps(res_0, res_1);
+ res =_mm_hadd_ps(res, res);
+ res =_mm_hadd_ps(res, res);
+
+ return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON)
+
+#if !defined(__aarch64__)
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+ int32x4_t res;
+
+ res[0] = roundf(vgetq_lane_f32(v, 0));
+ res[1] = roundf(vgetq_lane_f32(v, 1));
+ res[2] = roundf(vgetq_lane_f32(v, 2));
+ res[3] = roundf(vgetq_lane_f32(v, 3));
+
+ return res;
+}
+
+#endif
+#endif
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
+#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q4_0_reference(x, y, k);
+}
+
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+ const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+ y[i].m = ggml_fp32_to_fp16(min);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
+ }
+ }
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q4_1_reference(x, y, k);
+}
+
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(qh));
+ }
+}
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q5_0_reference(x, y, k);
+}
+
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+ const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 5) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+ y[i].m = ggml_fp32_to_fp16(min);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+ }
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q5_1_reference(x, y, k);
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = x[i*QK8_0 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = x[i*QK8_0 + j]*id;
+
+ y[i].qs[j] = roundf(x0);
+ }
+ }
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+ }
+ }
+#elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+ }
+ }
+#elif defined(__AVX2__) || defined(__AVX__)
+ for (int i = 0; i < nb; i++) {
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ y[i].d = ggml_fp32_to_fp16(d);
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+ // Since we don't have in AVX some necessary functions,
+ // we split the registers in half and call AVX2 analogs from SSE
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+ }
+#elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = ggml_fp32_to_fp16(d);
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ }
+#else
+ // scalar
+ quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+ assert(QK8_1 == 32);
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_1; j++) {
+ const float v = x[i*QK8_1 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ int sum = 0;
+
+ for (int j = 0; j < QK8_1/2; ++j) {
+ const float v0 = x[i*QK8_1 + j]*id;
+ const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+ y[i].qs[ j] = roundf(v0);
+ y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+ sum += y[i].qs[ j];
+ sum += y[i].qs[QK8_1/2 + j];
+ }
+
+ y[i].s = sum*d;
+ }
+}
+
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ int32x4_t accv = vdupq_n_s32(0);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+ accv = vaddq_s32(accv, vi);
+ }
+
+ y[i].s = d * vaddvq_s32(accv);
+ }
+#elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ v128_t accv = wasm_i32x4_splat(0);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+ accv = wasm_i32x4_add(accv, vi);
+ }
+
+ y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+ wasm_i32x4_extract_lane(accv, 1) +
+ wasm_i32x4_extract_lane(accv, 2) +
+ wasm_i32x4_extract_lane(accv, 3));
+ }
+#elif defined(__AVX2__) || defined(__AVX__)
+ for (int i = 0; i < nb; i++) {
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ y[i].d = d;
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+ // Compute the sum of the quants and set y[i].s
+ y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+ // Since we don't have in AVX some necessary functions,
+ // we split the registers in half and call AVX2 analogs from SSE
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Compute the sum of the quants and set y[i].s
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+ y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+ }
+#elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+ // compute sum for y[i].s
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+ // set y[i].s
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+ y[i].s = sum*d;
+ }
+#else
+ // scalar
+ quantize_row_q8_1_reference(x, y, k);
+#endif
+}
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = ggml_fp16_to_fp32(x[i].d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F) - 8;
+ const int x1 = (x[i].qs[j] >> 4) - 8;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+ static const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = ggml_fp16_to_fp32(x[i].d);
+ const float m = ggml_fp16_to_fp32(x[i].m);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F);
+ const int x1 = (x[i].qs[j] >> 4);
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = ggml_fp16_to_fp32(x[i].d);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+ static const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = ggml_fp16_to_fp32(x[i].d);
+ const float m = ggml_fp16_to_fp32(x[i].m);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK8_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = ggml_fp16_to_fp32(x[i].d);
+
+ for (int j = 0; j < qk; ++j) {
+ y[i*qk + j] = x[i].qs[j]*d;
+ }
+ }
+}
+
//
// 2-6 bit quantization in super-blocks
//
@@ -1264,15 +2283,6 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
//
#if __AVX__ || __AVX2__ || __AVX512F__
-// horizontally add 8 floats
-static inline float hsum_float_8(const __m256 x) {
- __m128 res = _mm256_extractf128_ps(x, 1);
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
- return _mm_cvtss_f32(res);
-}
-
// shuffles to pick the required scales in dot products
static inline __m256i get_scale_shuffle_q3k(int i) {
static const uint8_t k_shuffle[128] = {
@@ -1311,6 +2321,1224 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+
+ const block_q4_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q4_0 * restrict x0 = &x[i + 0];
+ const block_q4_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i + 0];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+ const int8x16_t s8b = vdupq_n_s8(0x8);
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // sub 8
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ // dot product into int32x4_t
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
+
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#endif
+ }
+
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = _mm256_set1_epi8( 8 );
+ bx = _mm256_sub_epi8( bx, off );
+
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps( d, q, acc );
+ }
+
+ *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ // Compute combined scale for the block
+ const __m256 d = _mm256_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
+
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ const __m128i off = _mm_set1_epi8(8);
+
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+ __m128i bx = _mm_and_si128(lowMask, tmp);
+ __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
+ bx = _mm_sub_epi8(bx, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
+
+ bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+ by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+ bx = _mm_sub_epi8(bx, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
+
+ // Convert int32_t to float
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
+
+ // Apply the scale, and accumulate
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+ }
+
+ *s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+ // set constants
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ const __m128i off = _mm_set1_epi8(8);
+
+ // Initialize accumulator with zeros
+ __m128 acc_0 = _mm_setzero_ps();
+ __m128 acc_1 = _mm_setzero_ps();
+ __m128 acc_2 = _mm_setzero_ps();
+ __m128 acc_3 = _mm_setzero_ps();
+
+ // First round without accumulation
+ {
+ _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 0 and 1
+ const __m128 d_0_1 = _mm_set1_ps( ggml_fp16_to_fp32(x[0].d) * ggml_fp16_to_fp32(y[0].d) );
+
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+ bx_1 = _mm_sub_epi8(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 2 and 3
+ const __m128 d_2_3 = _mm_set1_ps( ggml_fp16_to_fp32(x[1].d) * ggml_fp16_to_fp32(y[1].d) );
+
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+ bx_2 = _mm_sub_epi8(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+ bx_3 = _mm_sub_epi8(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+ // Apply the scale
+ acc_0 = _mm_mul_ps( d_0_1, p0 );
+ acc_1 = _mm_mul_ps( d_0_1, p1 );
+ acc_2 = _mm_mul_ps( d_2_3, p2 );
+ acc_3 = _mm_mul_ps( d_2_3, p3 );
+ }
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ // Main loop
+ for (int i = 2; i < nb; i+=2) {
+ _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 0 and 1
+ const __m128 d_0_1 = _mm_set1_ps( ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d) );
+
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+ bx_1 = _mm_sub_epi8(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for the block 2 and 3
+ const __m128 d_2_3 = _mm_set1_ps( ggml_fp16_to_fp32(x[i + 1].d) * ggml_fp16_to_fp32(y[i + 1].d) );
+
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+ bx_2 = _mm_sub_epi8(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+ bx_3 = _mm_sub_epi8(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+ // Apply the scale
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+ // Acummulate
+ acc_0 = _mm_add_ps(p0_d, acc_0);
+ acc_1 = _mm_add_ps(p1_d, acc_1);
+ acc_2 = _mm_add_ps(p2_d, acc_2);
+ acc_3 = _mm_add_ps(p3_d, acc_3);
+ }
+
+ *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+ float sumf = 0.0;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+ // mask and store lower part of x, and then upper part
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ // subtract offset
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += sumi*ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d);
+ }
+
+ *s = sumf;
+#else
+ // scalar
+ float sumf = 0.0;
+
+ for (int i = 0; i < nb; i++) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[i].qs[j] & 0x0F) - 8;
+ const int v1 = (x[i].qs[j] >> 4) - 8;
+
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+ }
+
+ sumf += sumi*ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d);
+ }
+
+ *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+
+ const block_q4_1 * restrict x = vx;
+ const block_q8_1 * restrict y = vy;
+
+ // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ float summs = 0;
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q4_1 * restrict x0 = &x[i + 0];
+ const block_q4_1 * restrict x1 = &x[i + 1];
+ const block_q8_1 * restrict y0 = &y[i + 0];
+ const block_q8_1 * restrict y1 = &y[i + 1];
+
+ summs += ggml_fp16_to_fp32(x0->m) * y0->s + ggml_fp16_to_fp32(x1->m) * y1->s;
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ // dot product into int32x4_t
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), ggml_fp16_to_fp32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), ggml_fp16_to_fp32(x1->d)*y1->d);
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
+
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*y1->d);
+#endif
+ }
+
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0;
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ const float d0 = ggml_fp16_to_fp32(x[i].d);
+ const float d1 = y[i].d;
+
+ summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
+
+ const __m256 d0v = _mm256_set1_ps( d0 );
+ const __m256 d1v = _mm256_set1_ps( d1 );
+
+ // Compute combined scales
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+ const __m256 xy = mul_sum_us8_pairs_float(bx, by);
+
+ // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+ }
+
+ *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+ float sumf = 0.0;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+ // mask and store lower part of x, and then upper part
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
+ }
+
+ *s = sumf;
+#else
+ // scalar
+ float sumf = 0.0;
+
+ for (int i = 0; i < nb; i++) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[i].qs[j] & 0x0F);
+ const int v1 = (x[i].qs[j] >> 4);
+
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+ }
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
+ }
+
+ *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_0);
+
+ const block_q5_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ uint32_t qh0;
+ uint32_t qh1;
+
+ uint64_t tmp0[4];
+ uint64_t tmp1[4];
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q5_0 * restrict x0 = &x[i];
+ const block_q5_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ // extract the 5th bit via lookup table ((!b) << 4)
+ memcpy(&qh0, x0->qh, sizeof(qh0));
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
+
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+ const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+ const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+ const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+ const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#endif
+ }
+
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+ v128_t sumv = wasm_f32x4_splat(0.0f);
+
+ uint32_t qh;
+ uint64_t tmp[4];
+
+ // TODO: check if unrolling this is better
+ for (int i = 0; i < nb; ++i) {
+ const block_q5_0 * restrict x0 = &x[i];
+ const block_q8_0 * restrict y0 = &y[i];
+
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+ // extract the 5th bit
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_1[(qh >> 24) ];
+
+ const v128_t qhl = wasm_v128_load(tmp + 0);
+ const v128_t qhh = wasm_v128_load(tmp + 2);
+
+ const v128_t v0 = wasm_v128_load(x0->qs);
+
+ // 4-bit -> 8-bit
+ const v128_t v0l = wasm_v128_and (v0, m4b);
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+ // load y
+ const v128_t v1l = wasm_v128_load(y0->qs);
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+ // int8x16 -> int16x8
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+ // dot product
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+ wasm_i32x4_add(
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(ggml_fp16_to_fp32(x0->d) * ggml_fp16_to_fp32(y0->d))));
+ }
+
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+ bx = _mm256_or_si256(bx, bxhi);
+
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps(d, q, acc);
+ }
+
+ *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8((char)0xF0);
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_andnot_si128(bxhil, mask);
+ bxhih = _mm_andnot_si128(bxhih, mask);
+ __m128i bxl = _mm256_castsi256_si128(bx);
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx = MM256_SET_M128I(bxh, bxl);
+
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+ }
+
+ *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+ float sumf = 0.0;
+
+ uint32_t qh;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ // These tempory registers are for masking and shift operations
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+ for (int i = 0; i < nb; i++) {
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+ // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+ // ((qh & (1u << (j + 16))) >> (j + 12));
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+ // narrowing
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+ // load
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d)) * sumi;
+ }
+
+ *s = sumf;
+#else
+ // scalar
+ float sumf = 0.0;
+
+ for (int i = 0; i < nb; i++) {
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
+
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+ }
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d)) * sumi;
+ }
+
+ *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_1);
+
+ const block_q5_1 * restrict x = vx;
+ const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ float summs0 = 0.0f;
+ float summs1 = 0.0f;
+
+ uint32_t qh0;
+ uint32_t qh1;
+
+ uint64_t tmp0[4];
+ uint64_t tmp1[4];
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q5_1 * restrict x0 = &x[i];
+ const block_q5_1 * restrict x1 = &x[i + 1];
+ const block_q8_1 * restrict y0 = &y[i];
+ const block_q8_1 * restrict y1 = &y[i + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+ summs0 += ggml_fp16_to_fp32(x0->m) * y0->s;
+ summs1 += ggml_fp16_to_fp32(x1->m) * y1->s;
+
+ // extract the 5th bit via lookup table ((b) << 4)
+ memcpy(&qh0, x0->qh, sizeof(qh0));
+ memcpy(&qh1, x1->qh, sizeof(qh1));
+
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
+
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
+
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // add high bit
+ const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+ const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+ const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+ const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), ggml_fp16_to_fp32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), ggml_fp16_to_fp32(x1->d)*y1->d);
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
+
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), ggml_fp16_to_fp32(x0->d)*y0->d);
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), ggml_fp16_to_fp32(x1->d)*y1->d);
+#endif
+ }
+
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+ v128_t sumv = wasm_f32x4_splat(0.0f);
+
+ float summs = 0.0f;
+
+ uint32_t qh;
+ uint64_t tmp[4];
+
+ // TODO: check if unrolling this is better
+ for (int i = 0; i < nb; ++i) {
+ const block_q5_1 * restrict x0 = &x[i];
+ const block_q8_1 * restrict y0 = &y[i];
+
+ summs += ggml_fp16_to_fp32(x0->m) * y0->s;
+
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+ // extract the 5th bit
+ memcpy(&qh, x0->qh, sizeof(qh));
+
+ tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
+ tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
+ tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+ tmp[3] = table_b2b_0[(qh >> 24) ];
+
+ const v128_t qhl = wasm_v128_load(tmp + 0);
+ const v128_t qhh = wasm_v128_load(tmp + 2);
+
+ const v128_t v0 = wasm_v128_load(x0->qs);
+
+ // 4-bit -> 8-bit
+ const v128_t v0l = wasm_v128_and (v0, m4b);
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+ // add high bit
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+ // load y
+ const v128_t v1l = wasm_v128_load(y0->qs);
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+ // int8x16 -> int16x8
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+ // dot product
+ sumv = wasm_f32x4_add(sumv,
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+ wasm_f32x4_splat(ggml_fp16_to_fp32(x0->d) * y0->d)));
+ }
+
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ const __m256 dx = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d));
+
+ summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+ bx = _mm256_or_si256(bx, bxhi);
+
+ const __m256 dy = _mm256_set1_ps(y[i].d);
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8(0x10);
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ const __m256 dx = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d));
+
+ summs += ggml_fp16_to_fp32(x[i].m) * y[i].s;
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_and_si128(bxhil, mask);
+ bxhih = _mm_and_si128(bxhih, mask);
+ __m128i bxl = _mm256_castsi256_si128(bx);
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx = MM256_SET_M128I(bxh, bxl);
+
+ const __m256 dy = _mm256_set1_ps(y[i].d);
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
+
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+ float sumf = 0.0;
+
+ uint32_t qh;
+
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+ // temporary registers for shift operations
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+ for (int i = 0; i < nb; i++) {
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+ // load qh
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+ // ((qh >> (j + 0)) << 4) & 0x10;
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+ // ((qh >> (j + 12)) ) & 0x10;
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+ // narrowing
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+ // load
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
+ }
+
+ *s = sumf;
+#else
+ // scalar
+ float sumf = 0.0;
+
+ for (int i = 0; i < nb; i++) {
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ int sumi = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
+ const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
+
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+ }
+
+ sumf += (ggml_fp16_to_fp32(x[i].d)*y[i].d)*sumi + ggml_fp16_to_fp32(x[i].m)*y[i].s;
+ }
+
+ *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+
+ const block_q8_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+ assert(nb % 2 == 0); // TODO: handle odd nb
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q8_0 * restrict x0 = &x[i + 0];
+ const block_q8_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i + 0];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+ // load y
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+
+#else
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
+
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
+
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
+
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), ggml_fp16_to_fp32(x0->d)*ggml_fp16_to_fp32(y0->d));
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), ggml_fp16_to_fp32(x1->d)*ggml_fp16_to_fp32(y1->d));
+#endif
+ }
+
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__) || defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ // Compute combined scale for the block
+ const __m256 d = _mm256_set1_ps(ggml_fp16_to_fp32(x[i].d) * ggml_fp16_to_fp32(y[i].d));
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+ // Multiply q with scale and accumulate
+#if defined(__AVX2__)
+ acc = _mm256_fmadd_ps( d, q, acc );
+#else
+ acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
+ }
+
+ *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+ float sumf = 0.0;
+ size_t vl = __riscv_vsetvl_e8m1(qk);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
+ vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+
+ vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+ sumf += sumi*(ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d));
+ }
+
+ *s = sumf;
+#else
+ // scalar
+ float sumf = 0.0;
+
+ for (int i = 0; i < nb; i++) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk; j++) {
+ sumi += x[i].qs[j]*y[i].qs[j];
+ }
+
+ sumf += sumi*(ggml_fp16_to_fp32(x[i].d)*ggml_fp16_to_fp32(y[i].d));
+ }
+
+ *s = sumf;
+#endif
+}
+
#if QK_K == 256
void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
diff --git a/k_quants.h b/ggml-quants.h
index 9de089e7..d88f99e3 100644
--- a/k_quants.h
+++ b/ggml-quants.h
@@ -1,20 +1,14 @@
#pragma once
+// This is a private API for quantization and dequantization
+// Should not be used directly, use ggml.h instead
+
#include "ggml.h"
#include <stdint.h>
#include <assert.h>
#include <stddef.h>
-// Super-block size
-#ifdef GGML_QKK_64
-#define QK_K 64
-#define K_SCALE_SIZE 4
-#else
-#define QK_K 256
-#define K_SCALE_SIZE 12
-#endif
-
#ifndef static_assert
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
#define static_assert(cond, msg) _Static_assert(cond, msg)
@@ -23,10 +17,66 @@
#endif
#endif
+#define QK4_0 32
+typedef struct {
+ ggml_fp16_t d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+ ggml_fp16_t d; // delta
+ ggml_fp16_t m; // min
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+ ggml_fp16_t d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+ ggml_fp16_t d; // delta
+ ggml_fp16_t m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+ ggml_fp16_t d; // delta
+ int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+ float d; // delta
+ float s; // d * sum(qs[i])
+ int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
+
//
// Super-block quantization structures
//
+// Super-block size
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
// 2-bit quantization
// weight is represented as x = a * q + b
// 16 blocks of 16 elements each
@@ -127,6 +177,13 @@ static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_
// Quantization
+void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
+void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
+void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
+void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
+void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
+void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
+
void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
@@ -134,6 +191,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
+void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);
+
void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
@@ -142,6 +206,13 @@ void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
// Dequantization
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
+//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
+
void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
@@ -150,16 +221,14 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-
-// Quantization with histogram collection
-size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
-
diff --git a/ggml.c b/ggml.c
index 6f66bab0..95f72c35 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1,10 +1,7 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
#include "ggml.h"
-
-#ifdef GGML_USE_K_QUANTS
-#include "k_quants.h"
-#endif
+#include "ggml-quants.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@@ -443,21 +440,6 @@ static ggml_fp16_t table_exp_f16[1 << 16];
// precomputed f32 table for f16 (256 KB)
static float table_f32_f16[1 << 16];
-#if defined(__ARM_NEON) || defined(__wasm_simd128__)
-#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
-#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
-#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
-#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
-#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
-#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
-#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
-#define B8(c,s ) B7(c,s, c), B7(c,s, s)
-
-// precomputed tables for expanding 8bits to 8 bytes:
-static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
-static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
-#endif
-
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9.
@@ -587,1071 +569,8 @@ int64_t ggml_cycles_per_ms(void) {
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
-//
-// quantization
-//
-
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
-// multiply int8_t, add results pairwise twice
-static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
- // Get absolute values of x vectors
- const __m128i ax = _mm_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m128i sy = _mm_sign_epi8(y, x);
- // Perform multiplication and create 16-bit values
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
- const __m128i ones = _mm_set1_epi16(1);
- return _mm_madd_epi16(ones, dot);
-}
-
-#if __AVX__ || __AVX2__ || __AVX512F__
-// horizontally add 8 floats
-static inline float hsum_float_8(const __m256 x) {
- __m128 res = _mm256_extractf128_ps(x, 1);
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
- return _mm_cvtss_f32(res);
-}
-
-// horizontally add 8 int32_t
-static inline int hsum_i32_8(const __m256i a) {
- const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
- const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
- const __m128i sum64 = _mm_add_epi32(hi64, sum128);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-
-// horizontally add 4 int32_t
-static inline int hsum_i32_4(const __m128i a) {
- const __m128i hi64 = _mm_unpackhi_epi64(a, a);
- const __m128i sum64 = _mm_add_epi32(hi64, a);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-
-#if defined(__AVX2__) || defined(__AVX512F__)
-// spread 32 bits to 32 bytes { 0x00, 0xFF }
-static inline __m256i bytes_from_bits_32(const uint8_t * x) {
- uint32_t x32;
- memcpy(&x32, x, sizeof(uint32_t));
- const __m256i shuf_mask = _mm256_set_epi64x(
- 0x0303030303030303, 0x0202020202020202,
- 0x0101010101010101, 0x0000000000000000);
- __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
- const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
- bytes = _mm256_or_si256(bytes, bit_mask);
- return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
-}
-
-// Unpack 32 4-bit fields into 32 bytes
-// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
-{
- const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
- return _mm256_and_si256(lowMask, bytes);
-}
-
-// add int16_t pairwise and return as float vector
-static inline __m256 sum_i16_pairs_float(const __m256i x) {
- const __m256i ones = _mm256_set1_epi16(1);
- const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
- return _mm256_cvtepi32_ps(summed_pairs);
-}
-
-static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if __AVXVNNI__
- const __m256i zero = _mm256_setzero_si256();
- const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
- return _mm256_cvtepi32_ps(summed_pairs);
-#else
- // Perform multiplication and create 16-bit values
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
- return sum_i16_pairs_float(dot);
-#endif
-}
-
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-#if __AVXVNNIINT8__
- const __m256i zero = _mm256_setzero_si256();
- const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
- return _mm256_cvtepi32_ps(summed_pairs);
-#else
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(y, x);
- return mul_sum_us8_pairs_float(ax, sy);
-#endif
-}
-
-static inline __m128i packNibbles( __m256i bytes )
-{
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
-#if __AVX512F__
- const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
- bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
- return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
-#else
- const __m256i lowByte = _mm256_set1_epi16( 0xFF );
- __m256i high = _mm256_andnot_si256( lowByte, bytes );
- __m256i low = _mm256_and_si256( lowByte, bytes );
- high = _mm256_srli_epi16( high, 4 );
- bytes = _mm256_or_si256( low, high );
-
- // Compress uint16_t lanes into bytes
- __m128i r0 = _mm256_castsi256_si128( bytes );
- __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
- return _mm_packus_epi16( r0, r1 );
-#endif
-}
-#elif defined(__AVX__)
-// spread 32 bits to 32 bytes { 0x00, 0xFF }
-static inline __m256i bytes_from_bits_32(const uint8_t * x) {
- uint32_t x32;
- memcpy(&x32, x, sizeof(uint32_t));
- const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
- const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
- __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
- __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
- const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
- bytesl = _mm_or_si128(bytesl, bit_mask);
- bytesh = _mm_or_si128(bytesh, bit_mask);
- bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
- bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
- return MM256_SET_M128I(bytesh, bytesl);
-}
-
-// Unpack 32 4-bit fields into 32 bytes
-// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
-{
- // Load 16 bytes from memory
- __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
- __m128i tmph = _mm_srli_epi16(tmpl, 4);
- const __m128i lowMask = _mm_set1_epi8(0xF);
- tmpl = _mm_and_si128(lowMask, tmpl);
- tmph = _mm_and_si128(lowMask, tmph);
- return MM256_SET_M128I(tmph, tmpl);
-}
-
-// add int16_t pairwise and return as float vector
-static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
- const __m128i ones = _mm_set1_epi16(1);
- const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
- const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
- const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
- return _mm256_cvtepi32_ps(summed_pairs);
-}
-
-static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
- const __m128i axl = _mm256_castsi256_si128(ax);
- const __m128i axh = _mm256_extractf128_si256(ax, 1);
- const __m128i syl = _mm256_castsi256_si128(sy);
- const __m128i syh = _mm256_extractf128_si256(sy, 1);
- // Perform multiplication and create 16-bit values
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
- return sum_i16_pairs_float(doth, dotl);
-}
-
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
- const __m128i xl = _mm256_castsi256_si128(x);
- const __m128i xh = _mm256_extractf128_si256(x, 1);
- const __m128i yl = _mm256_castsi256_si128(y);
- const __m128i yh = _mm256_extractf128_si256(y, 1);
- // Get absolute values of x vectors
- const __m128i axl = _mm_sign_epi8(xl, xl);
- const __m128i axh = _mm_sign_epi8(xh, xh);
- // Sign the values of the y vectors
- const __m128i syl = _mm_sign_epi8(yl, xl);
- const __m128i syh = _mm_sign_epi8(yh, xh);
- // Perform multiplication and create 16-bit values
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
- return sum_i16_pairs_float(doth, dotl);
-}
-
-static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
-{
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
- const __m128i lowByte = _mm_set1_epi16( 0xFF );
- __m128i high = _mm_andnot_si128( lowByte, bytes1 );
- __m128i low = _mm_and_si128( lowByte, bytes1 );
- high = _mm_srli_epi16( high, 4 );
- bytes1 = _mm_or_si128( low, high );
- high = _mm_andnot_si128( lowByte, bytes2 );
- low = _mm_and_si128( lowByte, bytes2 );
- high = _mm_srli_epi16( high, 4 );
- bytes2 = _mm_or_si128( low, high );
-
- return _mm_packus_epi16( bytes1, bytes2);
-}
-#endif
-#elif defined(__SSSE3__)
-// horizontally add 4x4 floats
-static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
- __m128 res_0 =_mm_hadd_ps(a, b);
- __m128 res_1 =_mm_hadd_ps(c, d);
- __m128 res =_mm_hadd_ps(res_0, res_1);
- res =_mm_hadd_ps(res, res);
- res =_mm_hadd_ps(res, res);
-
- return _mm_cvtss_f32(res);
-}
-#endif // __AVX__ || __AVX2__ || __AVX512F__
-#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
-
-#if defined(__ARM_NEON)
-
-#if !defined(__aarch64__)
-
-inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
-}
-
-inline static float vaddvq_f32(float32x4_t v) {
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
-}
-
-inline static float vmaxvq_f32(float32x4_t v) {
- return
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
-inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
- int32x4_t res;
-
- res[0] = roundf(vgetq_lane_f32(v, 0));
- res[1] = roundf(vgetq_lane_f32(v, 1));
- res[2] = roundf(vgetq_lane_f32(v, 2));
- res[3] = roundf(vgetq_lane_f32(v, 3));
-
- return res;
-}
-
-#endif
-#endif
-
-#define QK4_0 32
-typedef struct {
- ggml_fp16_t d; // delta
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
-} block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
-
-#define QK4_1 32
-typedef struct {
- ggml_fp16_t d; // delta
- ggml_fp16_t m; // min
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
-} block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
-
-#define QK5_0 32
-typedef struct {
- ggml_fp16_t d; // delta
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
-} block_q5_0;
-static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
-
-#define QK5_1 32
-typedef struct {
- ggml_fp16_t d; // delta
- ggml_fp16_t m; // min
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
-} block_q5_1;
-static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
-
-#define QK8_0 32
-typedef struct {
- ggml_fp16_t d; // delta
- int8_t qs[QK8_0]; // quants
-} block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
-
-#define QK8_1 32
-typedef struct {
- float d; // delta
- float s; // d * sum(qs[i])
- int8_t qs[QK8_1]; // quants
-} block_q8_1;
-static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
- static const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
- if (amax < fabsf(v)) {
- amax = fabsf(v);
- max = v;
- }
- }
-
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = x[i*qk + 0 + j]*id;
- const float x1 = x[i*qk + qk/2 + j]*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
- y[i].qs[j] = xi0;
- y[i].qs[j] |= xi1 << 4;
- }
- }
-}
-
-static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
- quantize_row_q4_0_reference(x, y, k);
-}
-
-static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
- const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
-
- if (v < min) min = v;
- if (v > max) max = v;
- }
-
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
- y[i].m = GGML_FP32_TO_FP16(min);
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = (x[i*qk + 0 + j] - min)*id;
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
- y[i].qs[j] = xi0;
- y[i].qs[j] |= xi1 << 4;
- }
- }
-}
-
-static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
- quantize_row_q4_1_reference(x, y, k);
-}
-
-static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
- static const int qk = QK5_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
- if (amax < fabsf(v)) {
- amax = fabsf(v);
- max = v;
- }
- }
-
- const float d = max / -16;
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- uint32_t qh = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = x[i*qk + 0 + j]*id;
- const float x1 = x[i*qk + qk/2 + j]*id;
-
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
- // get the 5-th bit and store it in qh at the right position
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
- }
-
- memcpy(&y[i].qh, &qh, sizeof(qh));
- }
-}
-
-static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
- quantize_row_q5_0_reference(x, y, k);
-}
-
-static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
- const int qk = QK5_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
-
- if (v < min) min = v;
- if (v > max) max = v;
- }
-
- const float d = (max - min) / ((1 << 5) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
- y[i].m = GGML_FP32_TO_FP16(min);
-
- uint32_t qh = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = (x[i*qk + 0 + j] - min)*id;
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
- // get the 5-th bit and store it in qh at the right position
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
- }
-
- memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
- }
-}
-
-static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
- quantize_row_q5_1_reference(x, y, k);
-}
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
- assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = x[i*QK8_0 + j];
- amax = MAX(amax, fabsf(v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = x[i*QK8_0 + j]*id;
-
- y[i].qs[j] = roundf(x0);
- }
- }
-}
-
-static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
- assert(QK8_0 == 32);
- assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;
-
- block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- for (int i = 0; i < nb; i++) {
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
- }
-#elif defined(__wasm_simd128__)
- for (int i = 0; i < nb; i++) {
- v128_t srcv [8];
- v128_t asrcv[8];
- v128_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
- wasm_f32x4_extract_lane(amaxv[0], 1)),
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
- wasm_f32x4_extract_lane(amaxv[0], 3)));
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < 8; j++) {
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
- }
- }
-#elif defined(__AVX2__) || defined(__AVX__)
- for (int i = 0; i < nb; i++) {
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float maxScalar = _mm_cvtss_f32( max4 );
-
- // Quantize these floats
- const float d = maxScalar / 127.f;
- y[i].d = GGML_FP32_TO_FP16(d);
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- // Apply the multiplier
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- // Round to nearest integer
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- // Convert floats to integers
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
-#if defined(__AVX2__)
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
-#else
- // Since we don't have in AVX some necessary functions,
- // we split the registers in half and call AVX2 analogs from SSE
- __m128i ni0 = _mm256_castsi256_si128( i0 );
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
- __m128i ni2 = _mm256_castsi256_si128( i1 );
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
- __m128i ni4 = _mm256_castsi256_si128( i2 );
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
- __m128i ni6 = _mm256_castsi256_si128( i3 );
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
- // Convert int32 to int16
- ni0 = _mm_packs_epi32( ni0, ni1 );
- ni2 = _mm_packs_epi32( ni2, ni3 );
- ni4 = _mm_packs_epi32( ni4, ni5 );
- ni6 = _mm_packs_epi32( ni6, ni7 );
- // Convert int16 to int8
- ni0 = _mm_packs_epi16( ni0, ni2 );
- ni4 = _mm_packs_epi16( ni4, ni6 );
-
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
-#endif
- }
-#elif defined(__riscv_v_intrinsic)
-
- size_t vl = __riscv_vsetvl_e32m4(QK8_0);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
-
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
- // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
- // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
- }
-#else
- // scalar
- quantize_row_q8_0_reference(x, y, k);
-#endif
-}
-
-// reference implementation for deterministic creation of model files
-static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
- assert(QK8_1 == 32);
- assert(k % QK8_1 == 0);
- const int nb = k / QK8_1;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_1; j++) {
- const float v = x[i*QK8_1 + j];
- amax = MAX(amax, fabsf(v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- int sum = 0;
-
- for (int j = 0; j < QK8_1/2; ++j) {
- const float v0 = x[i*QK8_1 + j]*id;
- const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
-
- y[i].qs[ j] = roundf(v0);
- y[i].qs[QK8_1/2 + j] = roundf(v1);
-
- sum += y[i].qs[ j];
- sum += y[i].qs[QK8_1/2 + j];
- }
-
- y[i].s = sum*d;
- }
-}
-
-static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
- assert(k % QK8_1 == 0);
- const int nb = k / QK8_1;
-
- block_q8_1 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- for (int i = 0; i < nb; i++) {
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- int32x4_t accv = vdupq_n_s32(0);
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
-
- accv = vaddq_s32(accv, vi);
- }
-
- y[i].s = d * vaddvq_s32(accv);
- }
-#elif defined(__wasm_simd128__)
- for (int i = 0; i < nb; i++) {
- v128_t srcv [8];
- v128_t asrcv[8];
- v128_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
- wasm_f32x4_extract_lane(amaxv[0], 1)),
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
- wasm_f32x4_extract_lane(amaxv[0], 3)));
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- v128_t accv = wasm_i32x4_splat(0);
-
- for (int j = 0; j < 8; j++) {
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
-
- accv = wasm_i32x4_add(accv, vi);
- }
-
- y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
- wasm_i32x4_extract_lane(accv, 1) +
- wasm_i32x4_extract_lane(accv, 2) +
- wasm_i32x4_extract_lane(accv, 3));
- }
-#elif defined(__AVX2__) || defined(__AVX__)
- for (int i = 0; i < nb; i++) {
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float maxScalar = _mm_cvtss_f32( max4 );
-
- // Quantize these floats
- const float d = maxScalar / 127.f;
- y[i].d = d;
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- // Apply the multiplier
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- // Round to nearest integer
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- // Convert floats to integers
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
-#if defined(__AVX2__)
- // Compute the sum of the quants and set y[i].s
- y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
-
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
-#else
- // Since we don't have in AVX some necessary functions,
- // we split the registers in half and call AVX2 analogs from SSE
- __m128i ni0 = _mm256_castsi256_si128( i0 );
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
- __m128i ni2 = _mm256_castsi256_si128( i1 );
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
- __m128i ni4 = _mm256_castsi256_si128( i2 );
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
- __m128i ni6 = _mm256_castsi256_si128( i3 );
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
- // Compute the sum of the quants and set y[i].s
- const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
- const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
- y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
-
- // Convert int32 to int16
- ni0 = _mm_packs_epi32( ni0, ni1 );
- ni2 = _mm_packs_epi32( ni2, ni3 );
- ni4 = _mm_packs_epi32( ni4, ni5 );
- ni6 = _mm_packs_epi32( ni6, ni7 );
- // Convert int16 to int8
- ni0 = _mm_packs_epi16( ni0, ni2 );
- ni4 = _mm_packs_epi16( ni4, ni6 );
-
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
-#endif
- }
-#elif defined(__riscv_v_intrinsic)
-
- size_t vl = __riscv_vsetvl_e32m4(QK8_1);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
-
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
- // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
- // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
-
- // compute sum for y[i].s
- vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
-
- // set y[i].s
- int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
- y[i].s = sum*d;
- }
-#else
- // scalar
- quantize_row_q8_1_reference(x, y, k);
-#endif
-}
-
-static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
- static const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
- const int x1 = (x[i].qs[j] >> 4) - 8;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
-}
-
-static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
- static const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
- const float m = GGML_FP16_TO_FP32(x[i].m);
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F);
- const int x1 = (x[i].qs[j] >> 4);
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
-}
-
-static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
- static const int qk = QK5_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
-
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
-}
-
-static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
- static const int qk = QK5_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
- const float m = GGML_FP16_TO_FP32(x[i].m);
-
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
-
- const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
- const int x1 = (x[i].qs[j] >> 4) | xh_1;
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
-}
-
-static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
- static const int qk = QK8_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- const block_q8_0 * restrict x = vx;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- for (int j = 0; j < qk; ++j) {
- y[i*qk + j] = x[i].qs[j]*d;
- }
- }
-}
-
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
-static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
@@ -1740,7 +659,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.blck_size = QK8_0,
.type_size = sizeof(block_q8_0),
.is_quantized = true,
- .to_float = dequantize_row_q8_0,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_0,
.from_float = quantize_row_q8_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
@@ -1755,7 +674,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_TYPE_Q8_1,
},
-#ifdef GGML_USE_K_QUANTS
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
@@ -1818,7 +736,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.is_quantized = true,
.from_float = quantize_row_q8_K,
}
-#endif
};
// For internal test use
@@ -2442,1218 +1359,6 @@ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * rest
*s = sumf;
}
-static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int qk = QK8_0;
- const int nb = n / qk;
-
- assert(n % qk == 0);
-
- const block_q4_0 * restrict x = vx;
- const block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 0; i < nb; i += 2) {
- const block_q4_0 * restrict x0 = &x[i + 0];
- const block_q4_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i + 0];
- const block_q8_0 * restrict y1 = &y[i + 1];
-
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
- const int8x16_t s8b = vdupq_n_s8(0x8);
-
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
- // 4-bit -> 8-bit
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
- // sub 8
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-
- // load y
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
- }
-
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- // Main loop
- for (int i = 0; i < nb; ++i) {
- /* Compute combined scale for the block */
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
-
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
- const __m256i off = _mm256_set1_epi8( 8 );
- bx = _mm256_sub_epi8( bx, off );
-
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
- /* Multiply q with scale and accumulate */
- acc = _mm256_fmadd_ps( d, q, acc );
- }
-
- *s = hsum_float_8(acc);
-#elif defined(__AVX__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- // Main loop
- for (int i = 0; i < nb; ++i) {
- // Compute combined scale for the block
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
- const __m128i lowMask = _mm_set1_epi8(0xF);
- const __m128i off = _mm_set1_epi8(8);
-
- const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
-
- __m128i bx = _mm_and_si128(lowMask, tmp);
- __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
- bx = _mm_sub_epi8(bx, off);
- const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
-
- bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
- by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
- bx = _mm_sub_epi8(bx, off);
- const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
-
- // Convert int32_t to float
- __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
-
- // Apply the scale, and accumulate
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
- }
-
- *s = hsum_float_8(acc);
-#elif defined(__SSSE3__)
- // set constants
- const __m128i lowMask = _mm_set1_epi8(0xF);
- const __m128i off = _mm_set1_epi8(8);
-
- // Initialize accumulator with zeros
- __m128 acc_0 = _mm_setzero_ps();
- __m128 acc_1 = _mm_setzero_ps();
- __m128 acc_2 = _mm_setzero_ps();
- __m128 acc_3 = _mm_setzero_ps();
-
- // First round without accumulation
- {
- _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
- _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
-
- // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
-
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
- bx_0 = _mm_sub_epi8(bx_0, off);
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
- bx_1 = _mm_sub_epi8(bx_1, off);
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
- _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
- _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
-
- // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
-
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
-
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
- bx_2 = _mm_sub_epi8(bx_2, off);
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
- bx_3 = _mm_sub_epi8(bx_3, off);
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
- // Convert int32_t to float
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
-
- // Apply the scale
- acc_0 = _mm_mul_ps( d_0_1, p0 );
- acc_1 = _mm_mul_ps( d_0_1, p1 );
- acc_2 = _mm_mul_ps( d_2_3, p2 );
- acc_3 = _mm_mul_ps( d_2_3, p3 );
- }
-
- // Main loop
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 2; i < nb; i+=2) {
- _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
- _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
-
- // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
-
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
-
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
- bx_0 = _mm_sub_epi8(bx_0, off);
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
- bx_1 = _mm_sub_epi8(bx_1, off);
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
- _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
- _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
-
- // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
-
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
-
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
- bx_2 = _mm_sub_epi8(bx_2, off);
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
- bx_3 = _mm_sub_epi8(bx_3, off);
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
- // Convert int32_t to float
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
-
- // Apply the scale
- __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
- __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
- __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
- __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
-
- // Acummulate
- acc_0 = _mm_add_ps(p0_d, acc_0);
- acc_1 = _mm_add_ps(p1_d, acc_1);
- acc_2 = _mm_add_ps(p2_d, acc_2);
- acc_3 = _mm_add_ps(p3_d, acc_3);
- }
-
- *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
-#elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
- // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- // subtract offset
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
- }
-
- *s = sumf;
-#else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- int sumi = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const int v0 = (x[i].qs[j] & 0x0F) - 8;
- const int v1 = (x[i].qs[j] >> 4) - 8;
-
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
- }
-
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
- }
-
- *s = sumf;
-#endif
-}
-
-static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int qk = QK8_1;
- const int nb = n / qk;
-
- assert(n % qk == 0);
-
- const block_q4_1 * restrict x = vx;
- const block_q8_1 * restrict y = vy;
-
- // TODO: add WASM SIMD
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- float summs = 0;
-
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 0; i < nb; i += 2) {
- const block_q4_1 * restrict x0 = &x[i + 0];
- const block_q4_1 * restrict x1 = &x[i + 1];
- const block_q8_1 * restrict y0 = &y[i + 0];
- const block_q8_1 * restrict y1 = &y[i + 1];
-
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
-
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
- // 4-bit -> 8-bit
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
- // load y
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
- }
-
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__) || defined(__AVX__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- float summs = 0;
-
- // Main loop
- for (int i = 0; i < nb; ++i) {
- const float d0 = GGML_FP16_TO_FP32(x[i].d);
- const float d1 = y[i].d;
-
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
-
- const __m256 d0v = _mm256_set1_ps( d0 );
- const __m256 d1v = _mm256_set1_ps( d1 );
-
- // Compute combined scales
- const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- const __m256i bx = bytes_from_nibbles_32(x[i].qs);
- const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
-
- const __m256 xy = mul_sum_us8_pairs_float(bx, by);
-
- // Accumulate d0*d1*x*y
-#if defined(__AVX2__)
- acc = _mm256_fmadd_ps( d0d1, xy, acc );
-#else
- acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
-#endif
- }
-
- *s = hsum_float_8(acc) + summs;
-#elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
- // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
- }
-
- *s = sumf;
-#else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- int sumi = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const int v0 = (x[i].qs[j] & 0x0F);
- const int v1 = (x[i].qs[j] >> 4);
-
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
- }
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
- }
-
- *s = sumf;
-#endif
-}
-
-static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int qk = QK8_0;
- const int nb = n / qk;
-
- assert(n % qk == 0);
- assert(qk == QK5_0);
-
- const block_q5_0 * restrict x = vx;
- const block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- uint32_t qh0;
- uint32_t qh1;
-
- uint64_t tmp0[4];
- uint64_t tmp1[4];
-
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 0; i < nb; i += 2) {
- const block_q5_0 * restrict x0 = &x[i];
- const block_q5_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i];
- const block_q8_0 * restrict y1 = &y[i + 1];
-
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
- // extract the 5th bit via lookup table ((!b) << 4)
- memcpy(&qh0, x0->qh, sizeof(qh0));
- memcpy(&qh1, x1->qh, sizeof(qh1));
-
- tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
- tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
- tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
- tmp0[3] = table_b2b_1[(qh0 >> 24) ];
-
- tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
- tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
- tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
- tmp1[3] = table_b2b_1[(qh1 >> 24) ];
-
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
-
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
- // 4-bit -> 8-bit
- int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
- int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
- int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
- int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
- const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
- const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
- const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
- const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
-
- // load y
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
- }
-
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__wasm_simd128__)
- v128_t sumv = wasm_f32x4_splat(0.0f);
-
- uint32_t qh;
- uint64_t tmp[4];
-
- // TODO: check if unrolling this is better
- for (int i = 0; i < nb; ++i) {
- const block_q5_0 * restrict x0 = &x[i];
- const block_q8_0 * restrict y0 = &y[i];
-
- const v128_t m4b = wasm_i8x16_splat(0x0F);
-
- // extract the 5th bit
- memcpy(&qh, x0->qh, sizeof(qh));
-
- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
- tmp[3] = table_b2b_1[(qh >> 24) ];
-
- const v128_t qhl = wasm_v128_load(tmp + 0);
- const v128_t qhh = wasm_v128_load(tmp + 2);
-
- const v128_t v0 = wasm_v128_load(x0->qs);
-
- // 4-bit -> 8-bit
- const v128_t v0l = wasm_v128_and (v0, m4b);
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
-
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
- const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
- const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
-
- // load y
- const v128_t v1l = wasm_v128_load(y0->qs);
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
-
- // int8x16 -> int16x8
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
-
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
-
- // dot product
- sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
- wasm_i32x4_add(
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
- }
-
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
-#elif defined(__AVX2__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- // Main loop
- for (int i = 0; i < nb; i++) {
- /* Compute combined scale for the block */
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
-
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
- bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
- bx = _mm256_or_si256(bx, bxhi);
-
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
- /* Multiply q with scale and accumulate */
- acc = _mm256_fmadd_ps(d, q, acc);
- }
-
- *s = hsum_float_8(acc);
-#elif defined(__AVX__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
- __m128i mask = _mm_set1_epi8((char)0xF0);
-
- // Main loop
- for (int i = 0; i < nb; i++) {
- /* Compute combined scale for the block */
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
-
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
- bxhil = _mm_andnot_si128(bxhil, mask);
- bxhih = _mm_andnot_si128(bxhih, mask);
- __m128i bxl = _mm256_castsi256_si128(bx);
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
- bxl = _mm_or_si128(bxl, bxhil);
- bxh = _mm_or_si128(bxh, bxhih);
- bx = MM256_SET_M128I(bxh, bxl);
-
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
- /* Multiply q with scale and accumulate */
- acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
- }
-
- *s = hsum_float_8(acc);
-#elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // These tempory registers are for masking and shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-
- vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
- vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
-
- for (int i = 0; i < nb; i++) {
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
-
- // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
-
- // ((qh & (1u << (j + 16))) >> (j + 12));
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
- vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
- }
-
- *s = sumf;
-#else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- int sumi = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
- const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
-
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
- }
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
- }
-
- *s = sumf;
-#endif
-}
-
-static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int qk = QK8_1;
- const int nb = n / qk;
-
- assert(n % qk == 0);
- assert(qk == QK5_1);
-
- const block_q5_1 * restrict x = vx;
- const block_q8_1 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- float summs0 = 0.0f;
- float summs1 = 0.0f;
-
- uint32_t qh0;
- uint32_t qh1;
-
- uint64_t tmp0[4];
- uint64_t tmp1[4];
-
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 0; i < nb; i += 2) {
- const block_q5_1 * restrict x0 = &x[i];
- const block_q5_1 * restrict x1 = &x[i + 1];
- const block_q8_1 * restrict y0 = &y[i];
- const block_q8_1 * restrict y1 = &y[i + 1];
-
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
-
- summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
- summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
-
- // extract the 5th bit via lookup table ((b) << 4)
- memcpy(&qh0, x0->qh, sizeof(qh0));
- memcpy(&qh1, x1->qh, sizeof(qh1));
-
- tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
- tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
- tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
- tmp0[3] = table_b2b_0[(qh0 >> 24) ];
-
- tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
- tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
- tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
- tmp1[3] = table_b2b_0[(qh1 >> 24) ];
-
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
-
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-
- // 4-bit -> 8-bit
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-
- // add high bit
- const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
- const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
- const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
- const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
-
- // load y
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
-
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
-#endif
- }
-
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined(__wasm_simd128__)
- v128_t sumv = wasm_f32x4_splat(0.0f);
-
- float summs = 0.0f;
-
- uint32_t qh;
- uint64_t tmp[4];
-
- // TODO: check if unrolling this is better
- for (int i = 0; i < nb; ++i) {
- const block_q5_1 * restrict x0 = &x[i];
- const block_q8_1 * restrict y0 = &y[i];
-
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
-
- const v128_t m4b = wasm_i8x16_splat(0x0F);
-
- // extract the 5th bit
- memcpy(&qh, x0->qh, sizeof(qh));
-
- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
- tmp[3] = table_b2b_0[(qh >> 24) ];
-
- const v128_t qhl = wasm_v128_load(tmp + 0);
- const v128_t qhh = wasm_v128_load(tmp + 2);
-
- const v128_t v0 = wasm_v128_load(x0->qs);
-
- // 4-bit -> 8-bit
- const v128_t v0l = wasm_v128_and (v0, m4b);
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
-
- // add high bit
- const v128_t v0lf = wasm_v128_or(v0l, qhl);
- const v128_t v0hf = wasm_v128_or(v0h, qhh);
-
- // load y
- const v128_t v1l = wasm_v128_load(y0->qs);
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
-
- // int8x16 -> int16x8
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
-
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
-
- // dot product
- sumv = wasm_f32x4_add(sumv,
- wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
- }
-
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
-#elif defined(__AVX2__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- float summs = 0.0f;
-
- // Main loop
- for (int i = 0; i < nb; i++) {
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
-
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
-
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
- bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
- bx = _mm256_or_si256(bx, bxhi);
-
- const __m256 dy = _mm256_set1_ps(y[i].d);
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
-
- acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
- }
-
- *s = hsum_float_8(acc) + summs;
-#elif defined(__AVX__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
- __m128i mask = _mm_set1_epi8(0x10);
-
- float summs = 0.0f;
-
- // Main loop
- for (int i = 0; i < nb; i++) {
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
-
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
-
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
- bxhil = _mm_and_si128(bxhil, mask);
- bxhih = _mm_and_si128(bxhih, mask);
- __m128i bxl = _mm256_castsi256_si128(bx);
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
- bxl = _mm_or_si128(bxl, bxhil);
- bxh = _mm_or_si128(bxh, bxhih);
- bx = MM256_SET_M128I(bxh, bxl);
-
- const __m256 dy = _mm256_set1_ps(y[i].d);
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
-
- acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
- }
-
- *s = hsum_float_8(acc) + summs;
-#elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // temporary registers for shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
-
- for (int i = 0; i < nb; i++) {
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
-
- // load qh
- vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
-
- // ((qh >> (j + 0)) << 4) & 0x10;
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
-
- // ((qh >> (j + 12)) ) & 0x10;
- vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
- }
-
- *s = sumf;
-#else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- int sumi = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
-
- const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
- const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
-
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
- }
-
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
- }
-
- *s = sumf;
-#endif
-}
-
-static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int qk = QK8_0;
- const int nb = n / qk;
-
- assert(n % qk == 0);
-
- const block_q8_0 * restrict x = vx;
- const block_q8_0 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
- for (int i = 0; i < nb; i += 2) {
- const block_q8_0 * restrict x0 = &x[i + 0];
- const block_q8_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i + 0];
- const block_q8_0 * restrict y1 = &y[i + 1];
-
- const int8x16_t x0_0 = vld1q_s8(x0->qs);
- const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
- const int8x16_t x1_0 = vld1q_s8(x1->qs);
- const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
-
- // load y
- const int8x16_t y0_0 = vld1q_s8(y0->qs);
- const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
- const int8x16_t y1_0 = vld1q_s8(y1->qs);
- const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-
-#else
- const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
- const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
- const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
- const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
-
- const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
- const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
- const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
- const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
-
- const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
- const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
- const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
- const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
-#endif
- }
-
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__) || defined(__AVX__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
-
- // Main loop
- for (int i = 0; i < nb; ++i) {
- // Compute combined scale for the block
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
- __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
- // Multiply q with scale and accumulate
-#if defined(__AVX2__)
- acc = _mm256_fmadd_ps( d, q, acc );
-#else
- acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
-#endif
- }
-
- *s = hsum_float_8(acc);
-#elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
- size_t vl = __riscv_vsetvl_e8m1(qk);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
- vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
-
- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
-
- vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
- vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
-
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
- }
-
- *s = sumf;
-#else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- int sumi = 0;
-
- for (int j = 0; j < qk; j++) {
- sumi += x[i].qs[j]*y[i].qs[j];
- }
-
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
- }
-
- *s = sumf;
-#endif
-}
-
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -21001,7 +18706,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
result = ggml_quantize_q8_0(src + start, block, n, n, hist);
} break;
-#ifdef GGML_USE_K_QUANTS
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(start % QK_K == 0);
@@ -21032,7 +18736,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
} break;
-#endif
case GGML_TYPE_F16:
{
int elemsize = sizeof(ggml_fp16_t);
diff --git a/ggml.h b/ggml.h
index 08bff551..8c954904 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1930,12 +1930,19 @@ extern "C" {
// quantization
//
+ // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk
GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
+ GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
+
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
//
diff --git a/llama.cpp b/llama.cpp
index 3d431ee7..1d1db8fc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -19,13 +19,11 @@
#ifdef GGML_USE_MPI
# include "ggml-mpi.h"
#endif
-#ifdef GGML_USE_K_QUANTS
-# ifndef QK_K
-# ifdef GGML_QKK_64
-# define QK_K 64
-# else
-# define QK_K 256
-# endif
+#ifndef QK_K
+# ifdef GGML_QKK_64
+# define QK_K 64
+# else
+# define QK_K 256
# endif
#endif
@@ -8052,7 +8050,7 @@ struct no_init {
struct quantize_state_internal {
const llama_model & model;
const llama_model_quantize_params * params;
-#ifdef GGML_USE_K_QUANTS
+
int n_attention_wv = 0;
int n_feed_forward_w2 = 0;
int i_attention_wv = 0;
@@ -8060,7 +8058,7 @@ struct quantize_state_internal {
int n_k_quantized = 0;
int n_fallback = 0;
-#endif
+
quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
: model(model)
, params(params)
@@ -8125,7 +8123,6 @@ static void llama_convert_tensor_internal(
workers.clear();
}
-#ifdef GGML_USE_K_QUANTS
static ggml_type get_k_quant_type(
quantize_state_internal & qs,
ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
@@ -8237,7 +8234,6 @@ static ggml_type get_k_quant_type(
return new_type;
}
-#endif
static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type quantized_type;
@@ -8252,7 +8248,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break;
-#ifdef GGML_USE_K_QUANTS
// K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break;
case LLAMA_FTYPE_MOSTLY_Q3_K_S:
@@ -8263,7 +8258,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q5_K_S:
case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
-#endif
+
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
}
@@ -8304,7 +8299,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
gguf_set_val_u32(ctx_out, "general.file_type", ftype);
-#ifdef GGML_USE_K_QUANTS
for (int i = 0; i < ml.n_tensors; ++i) {
struct ggml_tensor * meta = ml.get_tensor_meta(i);
@@ -8322,7 +8316,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n",
__func__, qs.n_attention_wv, qs.n_feed_forward_w2, model.hparams.n_layer);
}
-#endif
size_t total_size_org = 0;
size_t total_size_new = 0;
@@ -8387,9 +8380,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (quantize) {
new_type = quantized_type;
-#ifdef GGML_USE_K_QUANTS
- new_type = get_k_quant_type(qs, new_type, tensor, ftype);
-#endif
+ if (!params->pure) {
+ new_type = get_k_quant_type(qs, new_type, tensor, ftype);
+ }
+
// If we've decided to quantize to the same type the tensor is already
// in then there's nothing to do.
quantize = tensor->type != new_type;
@@ -8514,12 +8508,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
LLAMA_LOG_INFO("\n");
}
}
-#ifdef GGML_USE_K_QUANTS
+
if (qs.n_fallback > 0) {
LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) incompatible with k-quants and required fallback quantization\n",
__func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
}
-#endif
}
static int llama_apply_lora_from_file_internal(
@@ -8844,6 +8837,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
/*.allow_requantize =*/ false,
/*.quantize_output_tensor =*/ true,
/*.only_copy =*/ false,
+ /*.pure =*/ false,
};
return result;
diff --git a/llama.h b/llama.h
index d901dcd9..6927bd60 100644
--- a/llama.h
+++ b/llama.h
@@ -191,6 +191,7 @@ extern "C" {
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
+ bool pure; // disable k-quant mixtures and quantize all tensors to the same type
} llama_model_quantize_params;
// grammar types