summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m5
-rw-r--r--ggml-metal.metal1177
-rw-r--r--ggml.c1
-rw-r--r--tests/test-backend-ops.cpp1
4 files changed, 135 insertions, 1049 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 419d8b9e..38da384b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1926,7 +1926,12 @@ static enum ggml_status ggml_metal_graph_compute(
{
nth0 = 4;
nth1 = 16;
+ #if QK_K == 64
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+ #else
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+ #endif
+
} break;
default:
{
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 9a29f57a..3a823e65 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -857,15 +857,16 @@ void mul_vec_q_n_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- int64_t ne10,
- int64_t ne12,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
@@ -942,7 +943,7 @@ kernel void kernel_mul_mv_q4_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
@@ -968,7 +969,7 @@ kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
@@ -994,7 +995,7 @@ kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
@@ -1020,7 +1021,7 @@ kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
@@ -1039,6 +1040,7 @@ void kernel_mul_mv_q8_0_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1119,7 +1121,7 @@ kernel void kernel_mul_mv_q8_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}
#define N_F32_F32 4
@@ -2709,6 +2711,7 @@ void kernel_mul_mv_q2_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2871,7 +2874,7 @@ kernel void kernel_mul_mv_q2_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
@@ -2888,6 +2891,7 @@ void kernel_mul_mv_q3_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3046,6 +3050,7 @@ void kernel_mul_mv_q3_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3135,7 +3140,7 @@ kernel void kernel_mul_mv_q3_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
#if QK_K == 256
@@ -3152,6 +3157,7 @@ void kernel_mul_mv_q4_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3265,6 +3271,7 @@ void kernel_mul_mv_q4_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3373,7 +3380,7 @@ kernel void kernel_mul_mv_q4_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q5_K_f32_impl(
@@ -3389,6 +3396,7 @@ void kernel_mul_mv_q5_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3579,7 +3587,7 @@ kernel void kernel_mul_mv_q5_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
void kernel_mul_mv_q6_K_f32_impl(
@@ -3595,6 +3603,7 @@ void kernel_mul_mv_q6_K_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3713,7 +3722,7 @@ kernel void kernel_mul_mv_q6_K_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
@@ -4396,6 +4405,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4485,6 +4495,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4593,11 +4604,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@@ -4687,11 +4699,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
+ threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
@@ -4794,7 +4806,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4822,7 +4834,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4846,7 +4858,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4875,7 +4887,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
- threadgroup float * shared_values [[threadgroup(0)]],
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -6022,685 +6034,52 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
// matrix-vector multiplication
//
-[[host_name("kernel_mul_mv_id_f32_f32")]]
-kernel void kernel_mul_mv_id_f32_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_f32_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_f16_f32")]]
-kernel void kernel_mul_mv_id_f16_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_f16_f32_impl(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_q8_0_f32")]]
-kernel void kernel_mul_mv_id_q8_0_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q8_0_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_0_f32")]]
-kernel void kernel_mul_mv_id_q4_0_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_1_f32")]]
-kernel void kernel_mul_mv_id_q4_1_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_0_f32")]]
-kernel void kernel_mul_mv_id_q5_0_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_1_f32")]]
-kernel void kernel_mul_mv_id_q5_1_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q2_K_f32")]]
-kernel void kernel_mul_mv_id_q2_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q2_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q3_K_f32")]]
-kernel void kernel_mul_mv_id_q3_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q3_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_K_f32")]]
-kernel void kernel_mul_mv_id_q4_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q4_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_K_f32")]]
-kernel void kernel_mul_mv_id_q5_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_q5_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q6_K_f32")]]
-kernel void kernel_mul_mv_id_q6_K_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
+typedef void (kernel_mul_mv_impl_t)(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]);
- kernel_mul_mv_q6_K_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
+typedef void (kernel_mul_mv2_impl_t)(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
-[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xxs_f32(
- device const char * src0s,
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -6719,45 +6098,19 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
+ impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
}
-[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xs_f32(
- device const char * src0s,
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -6776,45 +6129,18 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_xs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
+ impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}
-[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq3_xxs_f32(
- device const char * src0s,
+typedef void (mul_mv_impl_fn_t)(
+ device const char * src0,
device const char * src1,
device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -6833,40 +6159,14 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
- constant int & idx,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq3_xxs_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
-[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
-kernel void kernel_mul_mv_id_iq3_s_f32(
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
device const char * src0s,
device const char * src1,
device float * dst,
@@ -6903,27 +6203,36 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
device const char * src0 = src0s + id*nb02;
- kernel_mul_mv_iq3_s_f32_impl(
+ impl_fn(
src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
+ src1 + bid*nb11,
+ dst + bid*ne0,
ne00,
ne01,
ne02,
+ nb00,
+ nb01,
+ nb02,
ne10,
+ ne11,
ne12,
+ ne13,
+ nb10,
+ nb11,
+ nb12,
ne0,
ne1,
+ nb1,
r2,
r3,
shared_values,
tgpig,
+ tiitg,
tiisg,
sgitg);
}
-[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
-kernel void kernel_mul_mv_id_iq2_s_f32(
+typedef void (kernel_mul_mv_id_t)(
device const char * src0s,
device const char * src1,
device float * dst,
@@ -6952,257 +6261,29 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq2_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
-kernel void kernel_mul_mv_id_iq1_s_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_s_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
-kernel void kernel_mul_mv_id_iq1_m_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq1_m_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
-kernel void kernel_mul_mv_id_iq4_nl_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
- kernel_mul_mv_iq4_nl_f32_impl(
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
-kernel void kernel_mul_mv_id_iq4_xs_f32(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup float * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
-
- tgpig.z = tgpig.z%(ne12*ne13);
-
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
-
-#if QK_K == 64
- kernel_mul_mv_iq4_nl_f32_impl(
-#else
- kernel_mul_mv_iq4_xs_f32_impl(
+ uint sgitg[[simdgroup_index_in_threadgroup]]);
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>; // explicit instantiations of the shared kernel_mul_mv_id template: one per src0 type, each exported under its host_name string
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>; // mmv_fn wraps the type-specific impl and forwards only the arguments that impl takes
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; // q4_0 .. q5_1 share mul_vec_q_n_f32_impl, specialized by block type
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
+#if QK_K != 64 // no iq4_xs instantiation when QK_K == 64: the host side (ggml-metal.m) selects the IQ4_NL pipeline in that configuration
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
#endif
- src0,
- (device const float *) (src1 + bid*nb11),
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- ne10,
- ne12,
- ne0,
- ne1,
- r2,
- r3,
- shared_values,
- tgpig,
- tiisg,
- sgitg);
-}
+
diff --git a/ggml.c b/ggml.c
index 90af4342..14288d29 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11012,7 +11012,6 @@ static void ggml_compute_forward_mul_mat_id(
}
// initialize matrix_row_counts
- GGML_ASSERT(wdata == wdata_src1_end);
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 51b3487b..b5067595 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2014,6 +2014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) {
for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v));
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
}
}