summaryrefslogtreecommitdiff
path: root/pocs
diff options
context:
space:
mode:
authorsnadampal <87143774+snadampal@users.noreply.github.com>2024-02-11 07:22:33 -0600
committerGitHub <noreply@github.com>2024-02-11 15:22:33 +0200
commita07d0fee1f05c5c1dc49948ae1a3293db017275f (patch)
tree06614ff1364269493e4853333ced56802abd7284 /pocs
parente4640d8fdf56f14a6db3d092bcd3d2d315cb5d04 (diff)
ggml : add mmla kernels for quantized GEMM (#4966)
* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q8_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q4_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q4_1_q8_1 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: update unit tests for the new vec_dot interface * llama.cpp: add MATMUL_INT8 capability to system_info
Diffstat (limited to 'pocs')
-rw-r--r--pocs/vdot/q8dot.cpp4
-rw-r--r--pocs/vdot/vdot.cpp4
2 files changed, 4 insertions, 4 deletions
diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp
index 111770d5..1a52ff5e 100644
--- a/pocs/vdot/q8dot.cpp
+++ b/pocs/vdot/q8dot.cpp
@@ -156,8 +156,8 @@ int main(int argc, char** argv) {
t1 = std::chrono::high_resolution_clock::now();
float fs;
- if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, x40.data(), y.data());
- else funcs.vec_dot(kVecSize * QK4_1, &fs, x41.data(), y.data());
+ if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1);
+ else funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1);
t2 = std::chrono::high_resolution_clock::now();
t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
if (iloop > 3) ggml.addResult(fs, t);
diff --git a/pocs/vdot/vdot.cpp b/pocs/vdot/vdot.cpp
index 73ffcd1c..17e9e448 100644
--- a/pocs/vdot/vdot.cpp
+++ b/pocs/vdot/vdot.cpp
@@ -284,8 +284,8 @@ int main(int argc, char** argv) {
else {
auto vdot = ggml_internal_get_type_traits(funcs.vec_dot_type);
vdot.from_float(y1.data(), q8.data(), kVecSize);
- if (useQ4_1) funcs.vec_dot(kVecSize, &result, q41.data(), q8.data());
- else funcs.vec_dot(kVecSize, &result, q40.data(), q8.data());
+ if (useQ4_1) funcs.vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1);
+ else funcs.vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1);
}
sumq += result;
t2 = std::chrono::high_resolution_clock::now();