 ggml/src/ggml-metal.m     |  90
 ggml/src/ggml-metal.metal | 562
 2 files changed, 649 insertions(+), 3 deletions(-)
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 13d7b97b..4cd44cb9 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -115,6 +115,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KT,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KT,
+    //GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KT,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
@@ -158,6 +161,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KT_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KT_F32,
+    //GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KT_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
     //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
@@ -195,6 +201,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KT_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KT_F32,
+    //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KT_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
@@ -229,6 +238,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F32,
+    //GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16,
@@ -263,6 +275,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F16,
+    //GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
@@ -297,6 +312,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KT_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KT_F32,
+    //GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KT_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
@@ -747,6 +765,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K,  get_rows_iq4_k,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K,  get_rows_iq5_k,  true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,  get_rows_iq6_k,  true);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KT, get_rows_iq2_kt, true);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KT, get_rows_iq3_kt, true);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KT, get_rows_iq4_kt, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM, fused_rms_norm, ctx->support_simdgroup_reduction);
@@ -790,6 +811,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_K_F32,  mul_mv_iq4_k_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_K_F32,  mul_mv_iq5_k_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32,  mul_mv_iq6_k_f32,  ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KT_F32, mul_mv_iq2_kt_f32, ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KT_F32, mul_mv_iq3_kt_f32, ctx->support_simdgroup_reduction);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KT_F32, mul_mv_iq4_kt_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
@@ -827,6 +851,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32,  mul_mv_id_iq4_k_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32,  mul_mv_id_iq5_k_f32,  ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32,  mul_mv_id_iq6_k_f32,  ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KT_F32, mul_mv_id_iq2_kt_f32, ctx->support_simdgroup_reduction);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KT_F32, mul_mv_id_iq3_kt_f32, ctx->support_simdgroup_reduction);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KT_F32, mul_mv_id_iq4_kt_f32, ctx->support_simdgroup_reduction);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,  mul_mm_f32_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,  mul_mm_f16_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm);
@@ -861,6 +888,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,  mul_mm_iq4_k_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32,  mul_mm_iq5_k_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,  mul_mm_iq6_k_f32,  ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F32, mul_mm_iq2_kt_f32, ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F32, mul_mm_iq3_kt_f32, ctx->support_simdgroup_mm);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F32, mul_mm_iq4_kt_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16,  mul_mm_f32_f16,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16,  mul_mm_f16_f16,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, mul_mm_bf16_f16, ctx->support_simdgroup_mm);
@@ -895,6 +925,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16,  mul_mm_iq4_k_f16,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16,  mul_mm_iq5_k_f16,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16,  mul_mm_iq6_k_f16,  ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F16, mul_mm_iq2_kt_f16, ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F16, mul_mm_iq3_kt_f16, ctx->support_simdgroup_mm);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F16, mul_mm_iq4_kt_f16, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,  mul_mm_id_f32_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,  mul_mm_id_f16_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm);
@@ -929,6 +962,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32,  mul_mm_id_iq4_k_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32,  mul_mm_id_iq5_k_f32,  ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32,  mul_mm_id_iq6_k_f32,  ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KT_F32, mul_mm_id_iq2_kt_f32, ctx->support_simdgroup_mm);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KT_F32, mul_mm_id_iq3_kt_f32, ctx->support_simdgroup_mm);
+       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KT_F32, mul_mm_id_iq4_kt_f32, ctx->support_simdgroup_mm);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
@@ -2142,6 +2178,9 @@ static void ggml_metal_encode_node(
                    case GGML_TYPE_IQ4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32  ].pipeline; break;
                    case GGML_TYPE_IQ5_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32  ].pipeline; break;
                    case GGML_TYPE_IQ6_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32  ].pipeline; break;
+                   case GGML_TYPE_IQ2_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F32 ].pipeline; break;
+                   case GGML_TYPE_IQ3_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F32 ].pipeline; break;
+                   //case GGML_TYPE_IQ4_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F32 ].pipeline; break;
                    default: GGML_ABORT("MUL MAT-MAT not implemented");
                }
            }
@@ -2181,6 +2220,9 @@ static void ggml_metal_encode_node(
                    case GGML_TYPE_IQ4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16  ].pipeline; break;
                    case GGML_TYPE_IQ5_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16  ].pipeline; break;
                    case GGML_TYPE_IQ6_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16  ].pipeline; break;
+                   case GGML_TYPE_IQ2_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KT_F16 ].pipeline; break;
+                   case GGML_TYPE_IQ3_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KT_F16 ].pipeline; break;
+                   //case GGML_TYPE_IQ4_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KT_F16 ].pipeline; break;
                    default: GGML_ABORT("MUL MAT-MAT not implemented");
                }
            }
@@ -2440,6 +2482,24 @@ static void ggml_metal_encode_node(
                        nth1 = 16;
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ6_K_F32].pipeline;
                    } break;
+               case GGML_TYPE_IQ2_KT:
+                   {
+                       nth0 = 4;
+                       nth1 = 16;
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_KT_F32].pipeline;
+                   } break;
+               case GGML_TYPE_IQ3_KT:
+                   {
+                       nth0 = 4;
+                       nth1 = 16;
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KT_F32].pipeline;
+                   } break;
+               //case GGML_TYPE_IQ4_KT:
+               //    {
+               //        nth0 = 4;
+               //        nth1 = 16;
+               //        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KT_F32].pipeline;
+               //    } break;
                default:
                    {
                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@@ -2471,7 +2531,8 @@ static void ggml_metal_encode_node(
            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S||
-               src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) {
+               src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0 ||
+               src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) {
                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
            }
            else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
@@ -2596,6 +2657,9 @@ static void ggml_metal_encode_node(
                case GGML_TYPE_IQ4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32  ].pipeline; break;
                case GGML_TYPE_IQ5_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32  ].pipeline; break;
                case GGML_TYPE_IQ6_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32  ].pipeline; break;
+               case GGML_TYPE_IQ2_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KT_F32 ].pipeline; break;
+               case GGML_TYPE_IQ3_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KT_F32 ].pipeline; break;
+               //case GGML_TYPE_IQ4_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KT_F32 ].pipeline; break;
                default: GGML_ABORT("MUL_MAT_ID not implemented");
            }
@@ -2839,6 +2903,24 @@ static void ggml_metal_encode_node(
                        nth1 = 16;
                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline;
                    } break;
+               case GGML_TYPE_IQ2_KT:
+                   {
+                       nth0 = 4;
+                       nth1 = 16;
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KT_F32].pipeline;
+                   } break;
+               case GGML_TYPE_IQ3_KT:
+                   {
+                       nth0 = 4;
+                       nth1 = 16;
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KT_F32].pipeline;
+                   } break;
+               //case GGML_TYPE_IQ4_KT:
+               //    {
+               //        nth0 = 4;
+               //        nth1 = 16;
+               //        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KT_F32].pipeline;
+               //    } break;
                default:
                    {
                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@@ -2881,7 +2963,8 @@ static void ggml_metal_encode_node(
                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                    src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 ||
-                   src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) {
+                   src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K||
+                   src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
                else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) {
@@ -2962,6 +3045,9 @@ static void ggml_metal_encode_node(
                case GGML_TYPE_IQ4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K  ].pipeline; break;
                case GGML_TYPE_IQ5_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K  ].pipeline; break;
                case GGML_TYPE_IQ6_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K  ].pipeline; break;
+               case GGML_TYPE_IQ2_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KT ].pipeline; break;
+               case GGML_TYPE_IQ3_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KT ].pipeline; break;
+               //case GGML_TYPE_IQ4_KT: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KT ].pipeline; break;
                case GGML_TYPE_I32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                default: GGML_ABORT("not implemented");
            }
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index b792844d..a05a890e 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2596,7 +2596,7 @@ void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
-    if constexpr (is_same_v<type4x4, half4x4>) {
+    if (is_same_v<type4x4, half4x4>) {
         const half d = xb->d;
         for (int i = 0; i < 16; i++) {
             reg[i/4][i%4] = (half)qs[i + 16*il] * d;
@@ -6596,6 +6596,431 @@ void kernel_mul_mv_iq2_k_f32_impl(
     }
 }
 
+struct Trellis {
+    constexpr constant static uint32_t kmask1 = 0x8fff8fff;
+    constexpr constant static uint32_t kmask2 = 0x3b603b60;
+    constexpr constant static uint32_t ka  = 89226354;
+    constexpr constant static uint32_t kb  = 64248484;
+    constexpr constant static uint32_t ka1 = ka*ka;
+    constexpr constant static uint32_t kb1 = kb*ka+kb;
+    constexpr constant static uint32_t ka2 = ka1*ka;
+    constexpr constant static uint32_t kb2 = kb1*ka+kb;
+    constexpr constant static uint32_t ka3 = ka2*ka;
+    constexpr constant static uint32_t kb3 = kb2*ka+kb;
+    constexpr constant static uint32_t ka4 = ka3*ka;
+    constexpr constant static uint32_t kb4 = kb3*ka+kb;
+    constexpr constant static uint32_t ka5 = ka4*ka;
+    constexpr constant static uint32_t kb5 = kb4*ka+kb;
+    constexpr constant static uint32_t ka6 = ka5*ka;
+    constexpr constant static uint32_t kb6 = kb5*ka+kb;
+    constexpr constant static uint32_t ka7 = ka6*ka;
+    constexpr constant static uint32_t kb7 = kb6*ka+kb;
+
+    static inline half4 gen4(uint32_t val) {
+        thread uint32_t aux[4] = {((ka *val + kb ) & kmask1) ^ kmask2,
+                                  ((ka1*val + kb1) & kmask1) ^ kmask2,
+                                  ((ka2*val + kb2) & kmask1) ^ kmask2,
+                                  ((ka3*val + kb3) & kmask1) ^ kmask2};
+        const thread half * h = (const thread half *)aux;
+        return { h[0]+h[1], h[2]+h[3], h[4]+h[5], h[6]+h[7] };
+    }
+    template <typename T4>
+    static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
+        thread uint32_t aux[8] = {((ka *val + kb ) & kmask1) ^ kmask2,
+                                  ((ka1*val + kb1) & kmask1) ^ kmask2,
+                                  ((ka2*val + kb2) & kmask1) ^ kmask2,
+                                  ((ka3*val + kb3) & kmask1) ^ kmask2,
+                                  ((ka4*val + kb4) & kmask1) ^ kmask2,
+                                  ((ka5*val + kb5) & kmask1) ^ kmask2,
+                                  ((ka6*val + kb6) & kmask1) ^ kmask2,
+                                  ((ka7*val + kb7) & kmask1) ^ kmask2};
+        const thread half * h = (const thread half *)aux;
+        if constexpr (is_same_v<T4, half4>) {
+            v1 = { h[0]+h[1], h[ 2]+h[ 3], h[ 4]+h[ 5], h[ 6]+h[ 7] };
+            v2 = { h[8]+h[9], h[10]+h[11], h[12]+h[13], h[14]+h[15] };
+        } else {
+            v1 = { (float)(h[0]+h[1]), (float)(h[ 2]+h[ 3]), (float)(h[ 4]+h[ 5]), (float)(h[ 6]+h[ 7]) };
+            v2 = { (float)(h[8]+h[9]), (float)(h[10]+h[11]), (float)(h[12]+h[13]), (float)(h[14]+h[15]) };
+        }
+    }
+};
+
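A note for readers meeting the "KT" (trellis) quants here for the first time: the Trellis struct above is the decoder's core. Each packed 16-bit code (biased by +4096 at the call sites) seeds eight steps of an affine 32-bit generator, ka_i*val + kb_i, whose coefficients ka1..ka7 and kb1..kb7 are precomputed powers so all eight states come straight from the seed; masking with kmask1 and XOR-ing kmask2 pins both fp16 lanes of each state into a small range, and summing adjacent lanes yields the weight values. A minimal CPU sketch of gen8, assuming IEEE binary16 semantics for Metal's half (the half_to_float helper is hand-rolled for portability and, like trellis_gen8 itself, is not part of the patch):

    #include <cstdint>
    #include <cstdio>
    #include <cmath>

    // Portable stand-in for reinterpreting a 16-bit lane as a Metal `half`.
    float half_to_float(uint16_t h) {
        const uint32_t sign = h >> 15, exp = (h >> 10) & 0x1f, mant = h & 0x3ff;
        float v;
        if      (exp == 0)  v = std::ldexp((float)mant, -24);            // subnormal
        else if (exp == 31) v = mant ? NAN : INFINITY;                   // inf/nan
        else                v = std::ldexp(1.0f + mant/1024.0f, (int)exp - 15);
        return sign ? -v : v;
    }

    // CPU mirror of Trellis::gen8: 8 affine steps from one seed; uint32_t
    // wrap-around matches the Metal code bit for bit.
    void trellis_gen8(uint32_t val, float out[8]) {
        constexpr uint32_t kmask1 = 0x8fff8fff, kmask2 = 0x3b603b60;
        constexpr uint32_t ka = 89226354, kb = 64248484;
        uint32_t a = ka, b = kb;             // a = ka^(i+1), b = the matching offset
        for (int i = 0; i < 8; ++i) {
            const uint32_t s = ((a*val + b) & kmask1) ^ kmask2;
            out[i] = half_to_float(s & 0xffff) + half_to_float(s >> 16);
            a *= ka;                         // next power of ka
            b  = b*ka + kb;                  // next offset: kb_{i+1} = kb_i*ka + kb
        }
    }

    int main() {
        float w[8];
        trellis_gen8(123 + 4096, w);         // the kernels bias every code by 4096
        for (float x : w) std::printf("%g ", x);
        std::printf("\n");
    }

Since everything is modulo-2^32 unsigned arithmetic, this reproduces the GPU sequence exactly, which makes it handy for spot-checking Metal output.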
+void kernel_mul_mv_iq2_kt_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3 tgpig,
+        uint  tiisg,
+        uint  sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const uint row_size = sizeof(float) + nb*sizeof(block_iq2_kt);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(ne01) + (i13/r3)*(ne01*ne02);
+
+    device const char  * cx = (device const char  *) src0 + (first_row + offset0)*row_size;
+    device const float * y  = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float4 sumf={0.f};
+
+    const int ix = tiisg/16;  // 0...1
+    const int it = tiisg%16;  // 0...15
+
+    device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
+
+    float4 v1, v2;
+
+    float drow[N_DST];
+    for (int row = 0; row < N_DST; ++row) {
+        device const float * dptr = (device const float *)(cx + row*row_size);
+        drow[row] = dptr[0] * 31.75f * 1.05f;
+    }
+
+    device const block_iq2_kt * x = (device const block_iq2_kt *)(cx + sizeof(float));
+
+    for (int ib = ix; ib < nb; ib += 2) {
+
+        device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            device const uint16_t * q2 = (device const uint16_t *)(sc + 4);
+
+            const float ls = drow[row] * iq4k_values[(sc[(it/2)%4] >> 4*(it/8)) & 0xf];
+
+            Trellis::gen8(q2[2*it+0]+4096, v1, v2);
+            auto sum = v1*y4[0] + v2*y4[1];
+
+            Trellis::gen8(q2[2*it+1]+4096, v1, v2);
+            sum += v1*y4[2] + v2*y4[3];
+
+            sum *= ls;
+
+            sumf[row] += sum[0] + sum[1] + sum[2] + sum[3];
+
+            sc += row_size;
+
+        }
+
+        y4 += QK_K/2;
+    }
+
+    sumf = simd_sum(sumf);
+    if (tiisg < 4) {
+        dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
+    }
+
+}
+
+[[host_name("kernel_mul_mv_iq2_kt_f32")]]
+kernel void kernel_mul_mv_iq2_kt_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant  int64_t  & ne00,
+        constant  int64_t  & ne01,
+        constant  int64_t  & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  int64_t  & ne10,
+        constant  int64_t  & ne11,
+        constant  int64_t  & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  int64_t  & ne0,
+        constant  int64_t  & ne1,
+        constant  uint     & r2,
+        constant  uint     & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_kt_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
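block_iq2_kt itself is not shown in this diff, but its layout can be read off the kernel just above: a row begins with a single float super-scale (folded with the 31.75f*1.05f codebook constant), followed by ne00/QK_K blocks holding four scale bytes (eight 4-bit indices into ggml's iq4k_values table, one per 32 weights) and one 16-bit trellis code per 8 weights, i.e. 2 bits per weight plus overhead. A scalar row dequantizer matching that reading, assuming QK_K = 256 and reusing trellis_gen8 from the sketch above; the iq4k_values numbers below are a stand-in (the real table ships with ggml):

    #include <cstdint>

    constexpr int QK_K = 256;                 // super-block size assumed throughout

    struct block_iq2_kt {                     // layout implied by the Metal kernels
        uint8_t  scales[QK_K/64];             // two 4-bit scale indices per byte
        uint16_t ql[QK_K/8];                  // trellis codes; seed = ql[k] + 4096
    };
    static_assert(sizeof(block_iq2_kt) == 68, "68 B / 256 weights ~ 2.125 bpw");

    // Stand-in values so the sketch compiles; ggml defines the real table.
    const int8_t iq4k_values[16] = {-127,-104,-83,-65,-49,-35,-22,-10,1,13,25,38,53,69,89,113};

    void trellis_gen8(uint32_t val, float out[8]);   // from the gen8 sketch above

    // Scalar reference for one row of n weights (n % QK_K == 0).
    void dequantize_row_iq2_kt_ref(const void * row, float * dst, int n) {
        const float d = *(const float *)row * 31.75f * 1.05f;
        const block_iq2_kt * x = (const block_iq2_kt *)((const char *)row + sizeof(float));
        for (int ib = 0; ib < n/QK_K; ++ib) {
            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
                const float ls = d * iq4k_values[(x[ib].scales[ib32%4] >> 4*(ib32/4)) & 0xf];
                for (int k = 0; k < 4; ++k) {
                    float w[8];
                    trellis_gen8(x[ib].ql[4*ib32 + k] + 4096u, w);
                    for (int j = 0; j < 8; ++j) dst[QK_K*ib + 32*ib32 + 8*k + j] = ls*w[j];
                }
            }
        }
    }

The same reading explains the kernel's row_size = sizeof(float) + nb*sizeof(block_iq2_kt) bookkeeping and the sc += row_size step, which walks to the same block of the next row.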
+void kernel_mul_mv_iq3_kt_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3 tgpig,
+        uint  tiisg,
+        uint  sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const uint row_size = sizeof(float) + nb*sizeof(block_iq3_kt);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(ne01) + (i13/r3)*(ne01*ne02);
+
+    device const char  * cx = (device const char  *) src0 + (first_row + offset0)*row_size;
+    device const float * y  = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float4 sumf={0.f};
+
+    const int ix = tiisg/16;  // 0...1
+    const int it = tiisg%16;  // 0...15
+
+    device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
+
+    float4 v[2];
+    thread uint32_t * u32 = (thread uint32_t *)v;
+
+    float drow[N_DST];
+    for (int row = 0; row < N_DST; ++row) {
+        device const float * dptr = (device const float *)(cx + row*row_size);
+        drow[row] = dptr[0] * 31.75f * 1.01f;
+    }
+
+    device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
+
+    for (int ib = ix; ib < nb; ib += 2) {
+
+        device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            device const uint16_t * q2 = (device const uint16_t *)(sc + 4);
+            device const uint8_t  * qh = (device const uint8_t *)(q2 + QK_K/8) + 16*(it%2);
+
+            const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
+            const uint8_t mask = 1 << (it/2);
+
+            Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
+            for (int j = 0; j < 8; ++j) {
+                u32[j] &= 0x7fffffff;
+                u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
+            }
+
+            auto sum = v[0]*y4[0] + v[1]*y4[1];
+
+            Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
+            for (int j = 0; j < 8; ++j) {
+                u32[j] &= 0x7fffffff;
+                u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
+            }
+
+            sum += v[0]*y4[2] + v[1]*y4[3];
+
+            sum *= ls;
+
+            sumf[row] += sum[0] + sum[1] + sum[2] + sum[3];
+
+            sc += row_size;
+
+        }
+
+        y4 += QK_K/2;
+    }
+
+    sumf = simd_sum(sumf);
+    if (tiisg < 4) {
+        dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
+    }
+
+}
+
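One detail worth calling out in the iq3_kt path: the trellis supplies magnitudes only, and the per-weight signs live in a separate qh bitplane (the byte selects the weight within the 16-byte window, bit it/2 selects the 32-weight group). Rather than a select on float values, the kernel splices the sign directly into the IEEE bit pattern through the u32 alias of v. A scalar equivalent of that splice (hypothetical helper, not in the patch):

    #include <cstdint>
    #include <cstring>

    // Force the trellis value positive, then copy the sign bit from the qh plane.
    float apply_iq3kt_sign(float v, uint8_t qh_byte, uint8_t mask) {
        uint32_t u;
        std::memcpy(&u, &v, sizeof u);         // well-defined type pun, unlike a cast
        u &= 0x7fffffffu;                      // |v|
        if (qh_byte & mask) u |= 0x80000000u;  // negative if the bitplane bit is set
        std::memcpy(&v, &u, sizeof v);
        return v;
    }

Note the mild asymmetry with dequantize_iq3_kt further down, which reaches the same result with v = abs(v)*scale followed by a conditional negate.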
+[[host_name("kernel_mul_mv_iq3_kt_f32")]]
+kernel void kernel_mul_mv_iq3_kt_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant  int64_t  & ne00,
+        constant  int64_t  & ne01,
+        constant  int64_t  & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  int64_t  & ne10,
+        constant  int64_t  & ne11,
+        constant  int64_t  & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  int64_t  & ne0,
+        constant  int64_t  & ne1,
+        constant  uint     & r2,
+        constant  uint     & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_kt_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+//TODO
+void kernel_mul_mv_iq4_kt_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3 tgpig,
+        uint  tiisg,
+        uint  sgitg) {
+
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const uint row_size = 2*sizeof(float) + nb*sizeof(block_iq4_kt);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(ne01) + (i13/r3)*(ne01*ne02);
+
+    device const char  * cx = (device const char  *) src0 + (first_row + offset0)*row_size;
+    device const float * y  = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float4 sumf={0.f};
+
+    const int ix = tiisg/16;  // 0...1
+    const int it = tiisg%16;  // 0...15
+
+    device const float4 * y4 = (device const float4 *)y + ix * (QK_K/4) + 4 * it;
+
+    float4 v[2];
+    thread uint32_t * u32 = (thread uint32_t *)v;
+
+    //float drow[2*N_DST];
+    //for (int row = 0; row < N_DST; ++row) {
+    //    device const float * dptr = (device const float *)(cx + row*row_size);
+    //    drow[2*row+0] = dptr[0] * 31.75f * 1.01f;
+    //    drow[2*row+1] = dptr[1];
+    //}
+    float drow[N_DST];
+    for (int row = 0; row < N_DST; ++row) {
+        device const float * dptr = (device const float *)(cx + row*row_size);
+        drow[row] = dptr[0] * 31.75f * 1.01f;
+    }
+
+    device const block_iq4_kt * x = (device const block_iq4_kt *)(cx + 2*sizeof(float));
+
+    for (int ib = ix; ib < nb; ib += 2) {
+
+        //auto sumy = y4[0] + y4[1] + y4[2] + y4[3];
+
+        device const uint32_t * shb = x[ib].qs;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            device const uint8_t * ql = (device const uint8_t *)(shb + 8);
+            device const uint8_t * qh = ql + 64;
+
+            const float ls = drow[row] * (((shb[it/2] & 0xff) >> 1) - 64);
+
+            const int jj = 8*(it/2) + 4*(it%2);
+            ql += jj;
+            qh += jj%32;
+
+            const uint32_t offset = 4096 + ((shb[it/2] & 1) << 15);
+            const int shift = 8 - 4*(jj/32);
+            uint32_t sh = (shb[it/2] >> (8 + 12*(it%2))) << 12;
+
+            float4 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                uint32_t idx = ql[j] + ((qh[j] << shift) & 0xf00) + ((sh >> 3*j) & 0x7000) + offset;
+                auto v = Trellis::gen4(idx);
+                sum += y4[j] * (float4)v;
+            }
+            sum *= ls;
+
+            //sumf[row] += sum[0] + sum[1] + sum[2] + sum[3] + drow[2*row+1]*(sumy[0] + sumy[1] + sumy[2] + sumy[3]);
+            sumf[row] += sum[0] + sum[1] + sum[2] + sum[3];
+
+            shb += row_size/4;
+
+        }
+
+        y4 += QK_K/2;
+    }
+
+    sumf = simd_sum(sumf);
+    if (tiisg < 4) {
+        dst[r1*ne0 + im*ne0*ne1 + first_row + tiisg] = sumf[tiisg];
+    }
+
+}
+
+[[host_name("kernel_mul_mv_iq4_kt_f32")]]
+kernel void kernel_mul_mv_iq4_kt_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant  int64_t  & ne00,
+        constant  int64_t  & ne01,
+        constant  int64_t  & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  int64_t  & ne10,
+        constant  int64_t  & ne11,
+        constant  int64_t  & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  int64_t  & ne0,
+        constant  int64_t  & ne1,
+        constant  uint     & r2,
+        constant  uint     & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_kt_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
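The iq4_kt mat-vec path above is marked //TODO and every host-side reference to it stays commented out, but the index math is already worth documenting. Each 4-weight group gets a 15-bit trellis seed assembled from three planes: 8 low bits in ql, 4 middle bits in a shared qh plane (shift 8 or 4 selects the nibble), and 3 top bits folded into the 32-bit group header shb, whose bit 0 additionally adds 2^15 on top of the constant +4096 bias; bits 1..7 of the header's low byte, minus 64, give the group scale. The addressing, pulled out into a scalar helper (names mirror the Metal code; illustrative only):

    #include <cstdint>

    // Rebuild the trellis seed for element j (0..3) of 4-weight group
    // (ib32_pair, half) of one iq4_kt super-block. shb: the 8 group headers;
    // ql/qh: the low/mid bit planes that follow them.
    uint32_t iq4kt_seed(const uint32_t * shb, const uint8_t * ql,
                        const uint8_t * qh, int ib32_pair, int half, int j) {
        const uint32_t h      = shb[ib32_pair];
        const uint32_t offset = 4096 + ((h & 1) << 15);     // +4096 bias, +32768 parity
        const int      jj     = 8*ib32_pair + 4*half;       // group index within block
        const int      shift  = 8 - 4*(jj/32);              // qh packs two 4-bit planes
        const uint32_t sh     = (h >> (8 + 12*half)) << 12; // 3 top bits per element
        return ql[jj + j] + ((qh[(jj%32) + j] << shift) & 0xf00)
                          + ((sh >> 3*j) & 0x7000) + offset;
    }

The commented-out drow[2*row+1]/sumy lines suggest the second float in the iq4_kt row header is an additive row bias that a later revision intends to apply.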
 [[host_name("kernel_mul_mv_iq2_k_f32")]]
 kernel void kernel_mul_mv_iq2_k_f32(
         device const  void * src0,
@@ -8114,6 +8539,71 @@ void dequantize_iq4_kss(device const block_iq4_kss * xb, short il, thread type4x
 }
 
 template <typename type4x4>
+void dequantize_iq2_kt(device const block_iq2_kt * x, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256
+    int ib32 = il/2;
+    half scale = iq4k_values[((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf)] * 31.75h * 1.05h;
+    device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
+
+    half4 v1, v2;
+    for (int i = 0; i < 2; ++i) {
+        Trellis::gen8(q2[i]+4096, v1, v2);
+        v1 *= scale; v2 *= scale;
+        if constexpr (is_same_v<type4x4, half4x4>) {
+            reg[2*i+0] = v1;
+            reg[2*i+1] = v2;
+        } else {
+            reg[2*i+0] = {(float)v1[0], (float)v1[1], (float)v1[2], (float)v1[3]};
+            reg[2*i+1] = {(float)v2[0], (float)v2[1], (float)v2[2], (float)v2[3]};
+        }
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256
+    int ib32 = il/2;
+    half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 31.75h * 1.01h;
+    device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
+    device const uint8_t  * qh = x->qh + 16*(il%2);
+    const uint8_t mask = 1 << ib32;
+
+    half4 v1, v2;
+    for (int i = 0; i < 2; ++i) {
+        Trellis::gen8(q2[i]+4096, v1, v2);
+        //v1 *= scale; v2 *= scale;
+        //for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -abs(v1[j]) : abs(v1[j]);
+        //for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -abs(v2[j]) : abs(v2[j]);
+        v1 = abs(v1)*scale; v2 = abs(v2)*scale;
+        for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -v1[j] : v1[j];
+        for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -v2[j] : v2[j];
+    }
+}
+
+void dequantize_iq4_kt(device const block_iq4_kt * x, short il, float d, thread float4x4 & reg) {
+    // il is 0...15 for QK_K = 256
+    int ib32 = il/2;
+    device const uint32_t * shb = x->qs;
+    device const uint8_t  * ql  = (device const uint8_t *)(shb + 8);
+    device const uint8_t  * qh  = ql + 64;
+    float scale = d * (((shb[ib32] & 0xff) >> 1) - 64);
+    const uint32_t offset = 4096 + ((shb[ib32] & 1) << 15);
+
+    const int jj = ib32*8 + 4*(il%2);
+    ql += jj;
+    qh += jj%32;
+
+    uint32_t sh = (shb[ib32] >> (8 + 12*(il%2))) << 12;
+    const int shift = 8 - 4*(jj/32);
+
+    for (int i = 0; i < 4; ++i) {
+        uint32_t idx = ql[i] + ((qh[i] << shift) & 0xf00) + ((sh >> 3*i) & 0x7000) + offset;
+        auto v = (float4)Trellis::gen4(idx);
+        reg[i] = v * scale;
+    }
+}
+
+template <typename type4x4>
 void dequantize_iq2_k(device const block_iq2_k * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256
     device const uint32_t * q32 = (device const uint32_t *)xb->qs + 8*(il/8) + 4*(il&1);
@@ -8409,6 +8899,61 @@ struct DequantizerRS {
     Scale d;
 };
 
+template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(device const Block *, short, thread T4x4&)>
+struct DequantizerRST4 {
+    using type4x4 = T4x4;
+    DequantizerRST4(device const char * cx, short il = 0) : il(il) {
+        device const Scale * dptr = (device const Scale *)cx;
+        d[0] = dptr[0] * Scale(31.75f * 1.01f);
+        d[1] = dptr[1];
+        x = (device const Block *)(dptr + 2);
+    }
+    inline void convert(thread T4x4& t) const {
+        dequantize(x, il, t);
+        for (int i = 0; i < 4; ++i) t[i] = t[i]*d[0] + d[1];
+    }
+    inline void convert(int64_t ind, thread T4x4& t) {
+        dequantize(x + ind/nl, ind%nl, t);
+        for (int i = 0; i < 4; ++i) t[i] = t[i]*d[0] + d[1];
+    }
+    inline void next() {
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+    }
+    device const Block * x;
+    short il;
+    Scale d[2];
+};
+
+template <typename T4x4, int nl>
+struct DequantizerKT4 {
+    using Block = block_iq4_kt;
+    using type4x4 = T4x4;
+    DequantizerKT4(device const char * cx, short il = 0) : il(il) {
+        device const float * dptr = (device const float *)cx;
+        d[0] = dptr[0] * 31.75f * 1.01f;
+        d[1] = dptr[1];
+        x = (device const Block *)(dptr + 2);
+    }
+    inline void convert(thread T4x4& t) const {
+        float4x4 tmp;
+        dequantize_iq4_kt(x, il, d[0], tmp);
+        for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
+    }
+    inline void convert(int64_t ind, thread T4x4& t) {
+        float4x4 tmp;
+        dequantize_iq4_kt(x + ind/nl, ind%nl, d[0], tmp);
+        for (int i = 0; i < 4; ++i) for (int j = 0; j < 4; ++j) t[i][j] = tmp[i][j];
+    }
+    inline void next() {
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+    }
+    device const Block * x;
+    short il;
+    float d[2];
+};
+
 template <typename T4x4, typename Block, typename Scale, int nl, void (*dequantize)(half d, device const Block *, short, thread T4x4&), bool may_not_be_aligned = false>
 struct DequantizerRSBN {
     using type4x4 = T4x4;
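Two remarks on the new dequantizers: DequantizerKT4 exists because dequantize_iq4_kt needs the row-level scale handed in (the per-group scales sit in the shb headers, while the float super-scale lives in front of the blocks), so the generic DequantizerRS, which applies its scale around a self-contained dequantize call, does not fit here. DequantizerRST4, the scale-plus-bias variant, is not referenced by any instantiation in this diff; presumably it anticipates the full IQ4_KT enablement, where the second header float would supply the additive bias.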
@@ -8849,6 +9394,9 @@ template [[host_name("kernel_get_rows_iq4_ks")]]  kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq5_ks")]]  kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
 template [[host_name("kernel_get_rows_iq4_kss")]] kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
 template [[host_name("kernel_get_rows_iq2_ks")]]  kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_ks, half,  16, dequantize_iq2_ks>>;
+template [[host_name("kernel_get_rows_iq2_kt")]]  kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
+template [[host_name("kernel_get_rows_iq3_kt")]]  kernel get_rows_q_t kernel_get_rows_q2<DequantizerRS<float4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
+template [[host_name("kernel_get_rows_iq4_kt")]]  kernel get_rows_q_t kernel_get_rows_q2<DequantizerKT4<float4x4, 16>>;
 
 //
 // matrix-matrix multiplication
@@ -8893,6 +9441,9 @@ template [[host_name("kernel_mul_mm_iq4_ks_f32")]]  kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq5_ks_f32")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, float>;
 template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, float>;
 template [[host_name("kernel_mul_mm_iq2_ks_f32")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half,  16, dequantize_iq2_ks>, float>;
+template [[host_name("kernel_mul_mm_iq2_kt_f32")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, float>;
+template [[host_name("kernel_mul_mm_iq3_kt_f32")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, float>;
+template [[host_name("kernel_mul_mm_iq4_kt_f32")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, float>;
 
 template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, half>;
 template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4,  1, dequantize_f16>, half>;
@@ -8928,6 +9479,9 @@ template [[host_name("kernel_mul_mm_iq4_ks_f16")]]  kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq5_ks_f16")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>, half>;
 template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, half>;
 template [[host_name("kernel_mul_mm_iq2_ks_f16")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half,  16, dequantize_iq2_ks>, half>;
+template [[host_name("kernel_mul_mm_iq2_kt_f16")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8,
DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>, half>;
+template [[host_name("kernel_mul_mm_iq3_kt_f16")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>, half>;
+template [[host_name("kernel_mul_mm_iq4_kt_f16")]]  kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerKT4<half4x4, 16>, half>;
 
 //
@@ -8970,6 +9524,9 @@ template [[host_name("kernel_mul_mm_id_iq4_ks_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq5_ks_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq5_ks, float, 16, dequantize_iq5_ks>>;
 template [[host_name("kernel_mul_mm_id_iq4_kss_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
 template [[host_name("kernel_mul_mm_id_iq2_ks_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
+template [[host_name("kernel_mul_mm_id_iq2_kt_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq2_kt, float, 16, dequantize_iq2_kt>>;
+template [[host_name("kernel_mul_mm_id_iq3_kt_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<DequantizerRS<half4x4, block_iq3_kt, float, 16, dequantize_iq3_kt>>;
+template [[host_name("kernel_mul_mm_id_iq4_kt_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<DequantizerKT4<half4x4, 16>>;
 
 //
 // matrix-vector multiplication
@@ -9188,6 +9745,9 @@ template [[host_name("kernel_mul_mv_id_iq5_ks_f32")]] kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_iq4_kss_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kss_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq2_k_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_k_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq2_ks_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_ks_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_kt_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_kt_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_kt_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_kt_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_kt_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_kt_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq3_k_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_k_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_k_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_k_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq5_k_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq5_k_f32_impl>>;
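As wired up here, only IQ2_KT and IQ3_KT are reachable from the host: every IQ4_KT registration and pipeline case in ggml-metal.m stays commented out, even though the corresponding .metal instantiations compile. Assuming this tree keeps upstream llama.cpp's test harness, test-backend-ops (e.g. with -o MUL_MAT) is the quickest way to compare the new Metal paths against the CPU reference for IQ2_KT/IQ3_KT tensors.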