Diffstat (limited to 'ggml/src')
-rw-r--r--  ggml/src/ggml-metal.m      556
-rw-r--r--  ggml/src/ggml-metal.metal  1263
2 files changed, 1248 insertions, 571 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index aa50d448..501fe5a2 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -225,6 +225,39 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
@@ -276,9 +309,33 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -290,7 +347,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
- GGML_METAL_KERNEL_TYPE_CONCAT,
+ GGML_METAL_KERNEL_TYPE_CONCAT_F32,
+ GGML_METAL_KERNEL_TYPE_CONCAT_F16,
GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -793,6 +851,39 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, mul_mm_f32_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, mul_mm_f16_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, mul_mm_bf16_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, mul_mm_q4_0_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, mul_mm_q4_1_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, mul_mm_q5_0_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, mul_mm_q5_1_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, mul_mm_q6_0_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, mul_mm_q8_0_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, mul_mm_q2_K_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, mul_mm_q3_K_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, mul_mm_q4_K_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, mul_mm_q5_K_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, mul_mm_q6_K_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, mul_mm_iq2_xxs_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, mul_mm_iq2_xs_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, mul_mm_iq3_xxs_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, mul_mm_iq3_s_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, mul_mm_iq2_s_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, mul_mm_iq1_s_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, mul_mm_iq1_m_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, mul_mm_iq4_k_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, mul_mm_iq5_k_f16, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, mul_mm_iq6_k_f16, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm);
@@ -844,9 +935,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
- //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,flash_attn_ext_f16_hk192_hv128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,flash_attn_ext_f16_hk576_hv512, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,flash_attn_ext_q8_0_hk192_hv128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,flash_attn_ext_q8_0_hk576_hv512, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,flash_attn_ext_vec_f16_hk192_hv128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,flash_attn_ext_vec_f16_hk576_hv512, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, flash_attn_ext_vec_q8_0_h80, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, flash_attn_ext_vec_q8_0_h112, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,flash_attn_ext_vec_q8_0_hk192_hv128, ctx->support_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,flash_attn_ext_vec_q8_0_hk576_hv512, ctx->support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@@ -858,7 +973,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F32, concat_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F16, concat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
@@ -1001,17 +1117,24 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
- if (op->src[1]->type != GGML_TYPE_F16) {
- return false;
+ if (!ctx->support_simdgroup_mm) {
+ return false; // TODO: over-restricted for vec-kernels
}
- if (op->src[2]->type != GGML_TYPE_F16) {
+ if (op->src[1]->type != op->src[2]->type ||
+ (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_Q8_0)) {
return false;
}
- if (op->src[0]->ne[0] == 256) {
- return false;
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+ return (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) ||
+ (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512);
}
- return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+ return (op->src[1]->ne[0] == 64 || op->src[1]->ne[0] == 80 ||
+ op->src[1]->ne[0] == 96 || op->src[1]->ne[0] == 112 ||
+ op->src[1]->ne[0] == 128 || op->src[1]->ne[0] == 256);
case GGML_OP_MUL_MAT:
+ return ctx->support_simdgroup_reduction &&
+ (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
+ !(op->src[0]->type >= GGML_TYPE_Q4_0_R8 && op->src[0]->type <= GGML_TYPE_Q8_K_R8);
case GGML_OP_MUL_MAT_ID:
return ctx->support_simdgroup_reduction &&
(op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
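Taken together, the FLASH_ATTN_EXT checks above accept F16 and Q8_0 KV caches (K and V must share a type) with equal K/V head sizes in {64, 80, 96, 112, 128, 256}, plus two asymmetric pairs. A standalone restatement of the head-size rule (a sketch; the helper name is illustrative, not from the patch):

    // (dk, dv) = K and V head sizes accepted by the Metal flash-attention path
    static bool fattn_head_sizes_supported(int dk, int dv) {
        if (dk != dv) {
            return (dk == 192 && dv == 128) || (dk == 576 && dv == 512);
        }
        return dk == 64 || dk == 80 || dk == 96 || dk == 112 ||
               dk == 128 || dk == 256;
    }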
@@ -1157,7 +1280,18 @@ static void ggml_metal_encode_node(
switch (dst->op) {
case GGML_OP_CONCAT:
{
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+ GGML_ASSERT(src0->type == src1->type && src0->type == dst->type);
+
+ id<MTLComputePipelineState> pipeline;
+ if (dst->type == GGML_TYPE_F32) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F32].pipeline;
+ }
+ else if (dst->type == GGML_TYPE_F16) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F16].pipeline;
+ }
+ else {
+ GGML_ABORT("CONCAT not implemented for this type");
+ }
const int32_t dim = ((int32_t *) dst->op_params)[0];
@@ -1945,7 +2079,7 @@ static void ggml_metal_encode_node(
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
- src1t == GGML_TYPE_F32 &&
+ (src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16) &&
ne00 % 32 == 0 && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
@@ -1960,41 +2094,84 @@ static void ggml_metal_encode_node(
id<MTLComputePipelineState> pipeline = nil;
- switch (src0->type) {
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
- case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
- case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
- case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
- case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
- case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
- case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
- case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
- case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break;
- case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break;
- case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break;
- default: GGML_ABORT("MUL MAT-MAT not implemented");
+ if (src1->type == GGML_TYPE_F32) {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break;
+ case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break;
+ case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break;
+ case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break;
+ default: GGML_ABORT("MUL MAT-MAT not implemented");
+ }
+ }
+ else if (src1->type == GGML_TYPE_F16) {
+ switch (src0->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16 ].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16 ].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16 ].pipeline; break;
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16 ].pipeline; break;
+ case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16 ].pipeline; break;
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16 ].pipeline; break;
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16 ].pipeline; break;
+ case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16 ].pipeline; break;
+ case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16 ].pipeline; break;
+ case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16 ].pipeline; break;
+ case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16 ].pipeline; break;
+ case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16 ].pipeline; break;
+ case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16].pipeline; break;
+ case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16 ].pipeline; break;
+ case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16].pipeline; break;
+ case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16 ].pipeline; break;
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16 ].pipeline; break;
+ case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16 ].pipeline; break;
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16 ].pipeline; break;
+ case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break;
+ case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break;
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break;
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break;
+ case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break;
+ case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break;
+ case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16 ].pipeline; break;
+ case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16 ].pipeline; break;
+ case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16 ].pipeline; break;
+ case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16 ].pipeline; break;
+ case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16 ].pipeline; break;
+ case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16 ].pipeline; break;
+ default: GGML_ABORT("MUL MAT-MAT not implemented");
+ }
+ }
+ else {
+ GGML_ABORT("Unsupported src1 type for MUL-MAT");
}
[encoder setComputePipelineState:pipeline];
@@ -3204,8 +3381,9 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
-
- GGML_ASSERT(ggml_are_same_shape (src1, src2));
+ GGML_ASSERT(src1->type == src2->type);
+ GGML_ASSERT(ne11 == ne21);
+ GGML_ASSERT(ne12 == ne22);
struct ggml_tensor * src3 = node->src[3];
@@ -3250,70 +3428,189 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false;
- if (ne01 >= 4 || (ne00%128 != 0)) {
- switch (ne00) {
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192 && ne00 != 576)) {
+ switch (src1->type) {
+ case GGML_TYPE_F16:
+ {
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+ }
+ else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ }
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+ }
+ else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ }
+ } break;
default:
- {
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
+ {
+ GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type));
+ GGML_METAL_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
}
} else {
use_vec_kernel = true;
-
- switch (ne00) {
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ switch (src1->type) {
+ case GGML_TYPE_F16:
+ {
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline;
+ }
+ else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ }
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ if (ne00 == 192 && ne20 == 128) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline;
+ }
+ else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline;
+ } else {
+ switch (ne00) {
+ case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64 ].pipeline; break;
+ case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80 ].pipeline; break;
+ case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00);
+ GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ }
+ } break;
default:
- {
- GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_METAL_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
+ {
+ GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type));
+ GGML_METAL_LOG_ERROR("add template specialization for this type\n");
+ GGML_ABORT("add template specialization for this type");
+ }
}
+
}
+ typedef struct {
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ uint64_t nb31;
+ int32_t ne1;
+ int32_t ne2;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint16_t n_head_log2;
+ float logit_softcap;
+ } ggml_metal_kargs_flash_attn_ext;
+
+ ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.nb31 =*/ nb31,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ softcap,
+ };
+
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
} else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
if (!use_vec_kernel) {
// half8x8 kernel
@@ -3324,10 +3621,19 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
+ // 2*(2*ncpsg + nqptg)*(nsg)
+ // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+ //
+ // 16*32*(nsg)
+ // the shared memory needed for the simdgroups to load the KV cache
+    // each thread loads (dequantizes) 16 head elements; there are 32 threads in the SG
+ //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+
int64_t nsgmax = 2;
while (true) {
- const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsgmax);
if (smem > ctx->device.maxThreadgroupMemoryLength) {
break;
}
@@ -3338,14 +3644,14 @@ static void ggml_metal_encode_node(
// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+        //printf("smem: %zu, max: %zu, nsg = %d\n", smem, ctx->device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
} else {
// half1x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
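A worked instance of the half8x8-kernel FATTN_SMEM macro above (nqptg = 8, ncpsg = 32; the trailing *(sizeof(float)/2) converts the half-unit counts to bytes):

    // ne00 = 128, nsg = 4: (8*(128 + 2*(2*32 + 8)*4) + 16*32*4) * 2 = 15360 bytes
    // ne00 = 576, nsg = 4: (8*(576 + 2*(2*32 + 8)*4) + 16*32*4) * 2 = 22528 bytes
    // the nsgmax probe loop picks the largest nsg whose footprint still fits
    // within maxThreadgroupMemoryLength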
@@ -3355,8 +3661,27 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 1 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
+    // nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg))
+    // for each query, we load it as f16 in shared memory (ne00, padded to a multiple of 128)
+    // and store the soft_max values and the mask (2*ncpsg floats per simdgroup)
+    //
+    // ne20*(nsg)
+    // each simdgroup has a full f16 V head vector in shared mem to accumulate results
+    //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+
+ int64_t nsgmax = 2;
+ while (true) {
+ const size_t smem = FATTN_SMEM(nsgmax);
+ if (smem > ctx->device.maxThreadgroupMemoryLength) {
+ break;
+ }
+ nsgmax *= 2;
+ }
+ nsgmax /= 2;
+
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
@@ -3364,13 +3689,14 @@ static void ggml_metal_encode_node(
}
nsg /= 2;
- const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+        //printf("smem: %zu, max: %zu, nsg = %d\n", smem, ctx->device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
-
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
}
} break;
case GGML_OP_DUP:
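Likewise for the vec-kernel FATTN_SMEM (nqptg = 1; assuming ncpsg = 32, the usual value for these kernels):

    // ne00 = 128, ne20 = 128, nsg = 4: (1*(128 + 4*32*4) + 128*4) * 2 = 2304 bytes
    // ne00 = 576, ne20 = 512, nsg = 2: (1*(640 + 4*32*2) + 512*2) * 2 = 3840 bytes
    //   (GGML_PAD(576, 128) = 640, hence the padded head size in the first term)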
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index a67ec336..d3a2858c 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2576,262 +2576,358 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
-typedef void (flash_attn_ext_f16_t)(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant float & softcap,
- constant uint32_t & n_head_log2,
- threadgroup half * shared,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]);
+//==========================================================================================
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ reg = (type4x4)(*src);
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ reg = (type4x4)(*src);
+}
+
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+ reg = (type4)(*(src));
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ if constexpr (is_same_v<type4x4, half4x4>) {
+ const half d = xb->d;
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = (half)qs[i + 16*il] * d;
+ }
+ } else {
+ const float d = xb->d;
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = qs[i + 16*il] * d;
+ }
+ }
+}
+
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const float d = xb->d;
+ for (int i = 0; i < 4; i++) {
+ reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+ }
+}
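These helpers slot into the kernel template's (deq, nl) parameter pairs below: nl is the number of 4x4 (16-element) chunks one quantization block yields, and il selects the chunk. The f32/f16 variants just copy, so nl is 1 there; a q8_0 block holds 32 quants and splits into two chunks, so nl = 2 at the q8_0 instantiation sites (further down in the file, not shown in this excerpt). A scalar sketch of the q8_0 index mapping (hypothetical helper, plain C for readability):

    // one dequantized value from a q8_0 block: scale d times the int8 quant;
    // chunk il in {0,1} covers quants 16*il .. 16*il+15 (j in 0..15)
    static inline float q8_0_value(float d, const int8_t qs[32], int il, int j) {
        return d * qs[16*il + j];
    }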
+
+typedef struct {
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne11;
+ int32_t ne_12_2; // assume K and V are same shape
+ int32_t ne_12_3;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb21;
+ uint64_t nb22;
+ uint64_t nb23;
+ uint64_t nb31;
+ int32_t ne1;
+ int32_t ne2;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+ uint16_t n_head_log2;
+ float logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
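This is a byte-for-byte duplicate of the host-side struct declared in ggml-metal.m above; the host fills it and ships it with setBytes, so the two declarations must stay in sync. A hedged host-side guard, not part of the patch (144 follows from the field list under standard C/MSL alignment; adjust if the struct changes):

    _Static_assert(sizeof(ggml_metal_kargs_flash_attn_ext) == 144,
                   "flash_attn_ext kargs layout out of sync with ggml-metal.metal");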
// ref: https://arxiv.org/pdf/2307.08691.pdf
-template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_f16(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant float & softcap,
- constant uint32_t & n_head_log2,
- threadgroup half * shared [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+template<
+ typename q_t, // query types in shared memory
+ typename q4_t,
+ typename q8x8_t,
+ typename k_t, // key types in shared memory
+ typename k4x4_t,
+ typename k8x8_t,
+ typename v_t, // value types in shared memory
+ typename v4x4_t,
+ typename v8x8_t,
+ typename qk_t, // Q*K types
+ typename qk8x8_t,
+ typename s_t, // soft-max types
+ typename s8x8_t,
+ typename o_t, // attention accumulation types
+ typename o4_t,
+ typename o8x8_t,
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short Q = 8, // queries per threadgroup
+ short KV = 8, // key/value processed per each simdgroup
+ short C = 32> // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
- const short iq3 = tgpig[2];
- const short iq2 = tgpig[1];
- const short iq1 = tgpig[0]*Q;
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0]*Q;
+
+ constexpr short DK4 = DK/4;
+ constexpr short DK8 = DK/8;
+ constexpr short DK16 = DK/16;
+ constexpr short DV4 = DV/4;
+ constexpr short DV8 = DV/8;
+ constexpr short DV16 = DV/16;
+
+ constexpr short NW = N_SIMDWIDTH;
+ constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) default: 72
- const short D4 = D/4;
- const short D8 = D/8;
- //const short Q8 = Q/8;
- const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float) = 288 for nsg = 4
+ const short T = DK + 2*TS; // shared memory size per query in (half) = 1152 for nsg = 4 and DK = 576
+ // => Q*T is 9216 for nsg = 4 and DK = 576 => 18432 bytes => overflows the 16384 bytes predicted as shmem in ggml-metal.m
- const short T = D + 2*nsg*SH; // shared memory size per query in (half)
- const short TF = T/2; // shared memory size per query in (float)
- const short T4 = T/4; // shared memory size per query in (half4)
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
- threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
- threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+ threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- simdgroup_half8x8 lo[D8];
+    o8x8_t lo[DV8]; // for DV = 512 we have DV8 = 64 => 64 8x8 accumulators (4096 values) per simdgroup. Do we even have that much register space?
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
- for (short i = tiisg; i < D4; i += NW) {
- if (iq1 + j < ne01) {
- sq4[j*T4 + i] = (half4) q4[i];
+ for (short i = tiisg; i < DK4; i += NW) {
+ if (iq1 + j < args.ne01) {
+ sq4[j*DK4 + i] = (q4_t) q4[i];
} else {
- sq4[j*T4 + i] = 0.0h;
+ sq4[j*DK4 + i] = (q4_t) 0.0f;
}
}
}
// zero out lo
- for (short i = 0; i < D8; ++i) {
- lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ for (short i = 0; i < DV8; ++i) {
+ lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
}
// zero out shared memory SH
for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) {
- ss[j*TF + i] = 0.0f;
+ ss[j*TS + i] = 0.0f;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S[Q] = { [0 ... Q-1] = 0.0h };
- float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
-
- // assume K and V are same shape
- const short ne22 = ne12;
- const short ne23 = ne13;
-
- // broadcast
- const short rk2 = ne02/ne12;
- const short rk3 = ne03/ne13;
-
- const short rv2 = ne02/ne22;
- const short rv3 = ne03/ne23;
+ float S[Q] = { [0 ... Q-1] = 0.0f };
+ float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
- // k indices
- const short ik2 = iq2/rk2;
- const short ik3 = iq3/rk3;
+ // thread indices inside the simdgroup
+ // TODO: see if we can utilize quad-group functions for better performance
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
- // v indices
- const short iv2 = iq2/rv2;
- const short iv3 = iq3/rv3;
+ // broadcast kv
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
- // load the queries from shared memory into local memory
- simdgroup_half8x8 mq[D8];
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- for (short i = 0; i < D8; ++i) {
- simdgroup_load(mq[i], sq + i*8, T);
- }
-
- // pointer to the mask
- device const half * mp = (device const half *) (mask + iq1*nb31);
+ const bool has_mask = mask != q;
float slope = 1.0f;
// ALiBi
- if (max_bias > 0.0f) {
- const uint32_t h = iq2;
+ if (args.max_bias > 0.0f) {
+ const short h = iq2;
- const float base = h < n_head_log2 ? m0 : m1;
- const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
}
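A worked instance of the slope formula above (hedged: m0 and m1 are computed on the host as in mainline llama.cpp, m0 = 2^(-max_bias/n_head_log2) and m1 = 2^(-(max_bias/2)/n_head_log2), with n_head_log2 the largest power of two not exceeding n_head):

    // max_bias = 8, n_head = 32  =>  n_head_log2 = 32,
    //   m0 = 2^(-8/32) ~= 0.8409,  m1 = 2^(-4/32) ~= 0.9170
    // head h = 0  -> slope = m0^(0+1)  ~= 0.841
    // head h = 31 -> slope = m0^(31+1) =  2^-8 = 0.00390625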
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
+ if (ic >= args.ne11) {
break;
}
- // Q*K^T
- {
- for (short cc = 0; cc < C/8; ++cc) {
- simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+ if (has_mask) {
+ // used to detect blocks full of -INF
+ float smax = -INFINITY;
- device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+ // load the mask in shared memory
+ #pragma unroll(Q)
+ for (short j = 0; j < Q; ++j) {
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
- for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mk;
- simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+ const float m = pm[ic + tiisg];
- simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
- }
+ ss[j*TS + C + tiisg] = m;
+ smax = max(smax, m);
+ }
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+ smax = simd_max(smax);
- const short tx = tiisg%4;
- const short ty = tiisg/4;
+ if (smax == -INFINITY) {
+ continue;
+ }
+ }
- // mqk = mqk*scale
- ss[8*cc + ty*TF + 2*tx + 0] *= scale;
- ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+ // Q*K^T
+ {
+ for (short cc = 0; cc < C/8; ++cc) {
+ qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
- if (softcap != 0.0f) {
- ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
- ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
- }
+                    // this is a compile-time check, so it has no runtime overhead
+ if (is_same<kd4x4_t, k4x4_t>::value) {
+ // we can read directly from global memory
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+
+ #pragma unroll(DK8)
+ for (short i = 0; i < DK8; ++i) {
+ k8x8_t mk;
+ simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
- if (mask != q) {
- // mqk = mqk*scale + mask*slope
- ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
- ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
+ q8x8_t mq;
+ simdgroup_load(mq, sq + i*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ }
+ } else {
+ for (short ii = 0; ii < DK16; ii += 4) {
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+
+ if (DK16%4 == 0) {
+ // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+ {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(4)
+ for (short k = 0; k < 4; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ }
+ } else {
+ if (ii + tx < DK16) {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < DK16; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ }
+ }
+ }
}
+
+ // cast qk_t -> s_t
+ //s8x8_t mqks(1.0f);
+ //simdgroup_multiply(mqks, mqk, mqks);
+ //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+ simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
}
}
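
The transposed loads above are what turn the accumulation into Q*K^T: simdgroup_load with transpose = true reads each 8x8 K tile column-major, so no transposed copy of K is ever materialized. A plain C++ sketch of the tile semantics (Tile, mma, and load_transposed are hypothetical stand-ins, not the Metal API):

    #include <array>

    using Tile = std::array<std::array<float, 8>, 8>; // stand-in for an 8x8 simdgroup matrix

    // role of simdgroup_multiply_accumulate: acc += a * b
    void mma(Tile & acc, const Tile & a, const Tile & b) {
        for (int i = 0; i < 8; ++i)
            for (int j = 0; j < 8; ++j)
                for (int k = 0; k < 8; ++k)
                    acc[i][j] += a[i][k] * b[k][j];
    }

    // role of simdgroup_load(mk, src, ld, 0, true): element (i,j) is read from src[j*ld + i]
    void load_transposed(Tile & t, const float * src, int ld) {
        for (int i = 0; i < 8; ++i)
            for (int j = 0; j < 8; ++j)
                t[i][j] = src[j*ld + i];
    }
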
- // used to detect blocks full of -INF
- float smax = -INFINITY;
-
// online softmax
{
- float ms[Q];
+ for (ushort j = 0; j < Q; ++j) {
+ const float m = M[j];
- for (short j = 0; j < Q; ++j) {
- const short p = tiisg;
+ // scale and apply the logit softcap / mask
+ float s = ss[j*TS + tiisg]*args.scale;
- const float m = M[j];
- const float s = ss[j*TF + p];
+ if (args.logit_softcap != 0.0f) {
+ s = args.logit_softcap*precise::tanh(s);
+ }
+
+ // mqk = mqk + mask*slope
+ s += slope*ss[j*TS + C + tiisg];
- smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
- ms[j] = exp(m - M[j]);
- const float vs = exp(s - M[j]);
+ const float ms = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
- S[j] = S[j]*ms[j] + simd_sum(vs);
+ S[j] = S[j]*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
- ss[j*TF + p] = vs;
- }
+ ss[j*TS + tiisg] = vs;
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg < Q) {
- ss[tiisg*TF + C + tiisg] = ms[tiisg];
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg == j) {
+ ss[j*TS + 2*C + j] = ms;
+ }
}
}
- // skip -INF blocks
- if (smax == -INFINITY) {
- continue;
- }
-
// O = diag(ms)*O
{
- simdgroup_float8x8 mm;
- simdgroup_load(mm, ss + C, TF, 0, false);
+ s8x8_t mm;
+ simdgroup_load(mm, ss + 2*C, TS, 0, false);
- for (short i = 0; i < D8; ++i) {
+ #pragma unroll(DV8)
+ for (short i = 0; i < DV8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]);
}
}
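
The M/S/diag(ms) bookkeeping above is the standard online-softmax update: each new block of scores may raise the running maximum, and everything accumulated so far is rescaled accordingly. A scalar C++ sketch of one step (hypothetical helper; M is the running max, S the running denominator, o the running output row):

    #include <cmath>

    // fold one new score s and its value row v (length n) into the running state
    void online_softmax_step(float s, const float * v, int n,
                             float & M, float & S, float * o) {
        const float Mnew = std::fmax(M, s);
        const float ms   = std::exp(M - Mnew); // rescale factor for the old state
        const float vs   = std::exp(s - Mnew); // softmax weight of the new score
        for (int i = 0; i < n; ++i) {
            o[i] = o[i]*ms + vs*v[i];          // O = diag(ms)*O + vs*V
        }
        S = S*ms + vs;
        M = Mnew;
    }
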
@@ -2839,16 +2935,64 @@ kernel void kernel_flash_attn_ext_f16(
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
- device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 8*cc, TS, 0, false);
- for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mk;
- simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+ if (is_same<vd4x4_t, v4x4_t>::value) {
+ // we can read directly from global memory
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
- simdgroup_float8x8 mv;
- simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+ #pragma unroll(DV8)
+ for (short i = 0; i < DV8; ++i) {
+ v8x8_t mv;
+ simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
- simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+ simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+ }
+ } else {
+ for (short ii = 0; ii < DV16; ii += 4) {
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+
+ if (DV16%4 == 0) {
+ // no need for bound checks
+ {
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(4)
+ for (short k = 0; k < 4; ++k) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ } else {
+ if (ii + tx < DV16) {
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < DV16; ++k) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ }
+ }
+ }
}
}
}
@@ -2857,23 +3001,23 @@ kernel void kernel_flash_attn_ext_f16(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) {
if (tiisg == 0) {
- ss[j*TF + 0] = S[j];
- ss[j*TF + 1] = M[j];
+ ss[j*TS + 0] = S[j];
+ ss[j*TS + 1] = M[j];
}
}
}
// reduce the warps sequentially
- for (short sg = 1; sg < nsg; ++sg) {
- float S = { 0.0h };
- float M = { -FLT_MAX/2 };
+ for (ushort sg = 1; sg < nsg; ++sg) {
+ float S = { 0.0f };
+ float M = { -__FLT16_MAX__/2 };
threadgroup_barrier(mem_flags::mem_threadgroup);
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
- for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ for (short i = 0; i < DV8; ++i) {
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
@@ -2882,11 +3026,11 @@ kernel void kernel_flash_attn_ext_f16(
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (short j = 0; j < Q; ++j) {
- const float S0 = ss[j*TF + 0];
- const float S1 = ss[j*TF + sg*SH + 0];
+ const float S0 = ss[j*TS + 0];
+ const float S1 = ss[j*TS + sg*SH + 0];
- const float M0 = ss[j*TF + 1];
- const float M1 = ss[j*TF + sg*SH + 1];
+ const float M0 = ss[j*TS + 1];
+ const float M1 = ss[j*TS + sg*SH + 1];
M = max(M0, M1);
@@ -2896,25 +3040,27 @@ kernel void kernel_flash_attn_ext_f16(
S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
- ss[j*TF + 0] = S;
- ss[j*TF + 1] = M;
+ ss[j*TS + 0] = S;
+ ss[j*TS + 1] = M;
- ss[j*TF + C + j ] = ms0;
- ss[j*TF + C + j + sg*SH] = ms1;
+ ss[j*TS + 2*C + j ] = ms0;
+ ss[j*TS + 2*C + j + sg*SH] = ms1;
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
- simdgroup_half8x8 t;
- simdgroup_float8x8 ms0;
- simdgroup_float8x8 ms1;
+ s8x8_t ms0;
+ s8x8_t ms1;
- simdgroup_load(ms0, ss + C, TF, 0, false);
- simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+ simdgroup_load(ms0, ss + 2*C, TS, 0, false);
+ simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
- for (short i = 0; i < D8; ++i) {
- simdgroup_load (t, sq + i*8, T, 0, false);
+ #pragma unroll(DV8)
+ for (short i = 0; i < DV8; ++i) {
+ o8x8_t t;
+
+ simdgroup_load (t, so + i*8, DV, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -2925,8 +3071,8 @@ kernel void kernel_flash_attn_ext_f16(
// store result to shared memory (reuse sq)
if (sgitg == 0) {
- for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ for (short i = 0; i < DV8; ++i) {
+ simdgroup_store(lo[i], so + i*8, DV, 0, false);
}
}
@@ -2934,206 +3080,246 @@ kernel void kernel_flash_attn_ext_f16(
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
- for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
- const float S = ss[j*TF + 0];
+ for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
+ const float S = ss[j*TS + 0];
- for (short i = tiisg; i < D4; i += NW) {
- dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+ for (short i = tiisg; i < DV4; i += NW) {
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
}
}
}
}
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
-
-template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec_f16(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant float & softcap,
- constant uint32_t & n_head_log2,
- threadgroup half * shared [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
+// template parameters to be able to explore different combinations
+//
+#define FA_TYPES \
+ half, half4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ float, simdgroup_float8x8, \
+ float, simdgroup_float8x8, \
+ half, half4, simdgroup_half8x8
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
+
+#undef FA_TYPES
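
A note on the typedef above: taking decltype of one concrete specialization names the kernel's function type once, and every [[host_name(...)]] instantiation can then share it. A minimal C++ analogue of the pattern (hypothetical names):

    // one template, many specializations that all share a function type
    template <typename T, int DK, int DV>
    void kernel_impl() { /* ... */ }

    typedef decltype(kernel_impl<float, 64, 64>) kernel_t; // i.e. the function type void ()

    // the shared type makes uniform tables of specializations possible
    kernel_t * table[] = { kernel_impl<float, 64, 64>, kernel_impl<float, 128, 128> };
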
+
+template<
+ typename q4_t, // query types in shared memory
+ typename k4_t, // key types in shared memory
+ typename v4_t, // value types in shared memory
+ typename qk_t, // Q*K types
+ typename s_t, // soft-max types
+ typename s4_t,
+ typename o4_t, // attention accumulation types
+ typename kd4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+ typename vd4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short NE = 4, // head elements per thread
+ short Q = 1, // queries per threadgroup
+ short C = 32> // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
- const short iq3 = tgpig[2];
- const short iq2 = tgpig[1];
- const short iq1 = tgpig[0];
+ const int iq3 = tgpig[2];
+ const int iq2 = tgpig[1];
+ const int iq1 = tgpig[0];
- const short D4 = D/4;
- const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ constexpr short DK4 = DK/4;
+ constexpr short DV4 = DV/4;
+ constexpr short NW = N_SIMDWIDTH;
+ constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads
+ constexpr short SH = 4*C; // shared memory per simdgroup
- const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ const short T = DK + nsg*SH; // shared memory size per query in (half)
- float slope = 1.0f;
-
- // ALiBi
- if (max_bias > 0.0f) {
- const uint32_t h = iq2;
-
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slope = pow(base, exp);
- }
-
- //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
- threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
- threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
- threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- half4 lo[D4/NW];
+ // store the result for all queries in local memory (the O matrix from the paper)
+ o4_t lo[DV4/NL];
// load heads from Q to shared memory
- device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
- for (short i = tiisg; i < D4; i += NW) {
- if (iq1 < ne01) {
- sq4[i] = (half4) q4[i];
+ for (short i = tiisg; i < DK4; i += NW) {
+ if (iq1 < args.ne01) {
+ sq4[i] = (q4_t) q4[i];
} else {
- sq4[i] = 0.0h;
+ sq4[i] = (q4_t) 0.0f;
}
}
// zero out lo
- for (short i = tiisg; i < D4; i += NW) {
- lo[i/NW] = 0.0h;
+ for (short i = 0; i < DV4/NL; ++i) {
+ lo[i] = (o4_t) 0.0f;
}
// zero out shared memory SH
for (short i = tiisg; i < SH/4; i += NW) {
- ss4[i] = 0.0h;
+ ss4[i] = (s4_t) 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S = { 0.0h };
- float M = { -FLT_MAX/2 };
+ float S = 0.0f;
+ float M = -__FLT16_MAX__/2;
- // assume K and V are same shape
- const short ne22 = ne12;
- const short ne23 = ne13;
+ // thread indices inside the simdgroup
+ const short tx = tiisg%NL;
+ const short ty = tiisg/NL;
- // broadcast
- const short rk2 = ne02/ne12;
- const short rk3 = ne03/ne13;
+ // broadcast kv
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
- const short rv2 = ne02/ne22;
- const short rv3 = ne03/ne23;
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
- // k indices
- const short ik2 = iq2 / rk2;
- const short ik3 = iq3 / rk3;
+ const bool has_mask = mask != q;
- // v indices
- const short iv2 = iq2 / rv2;
- const short iv3 = iq3 / rv3;
+ // pointer to the mask
+ device const half * pm = (device const half *) (mask + iq1*args.nb31);
- // load the queries from shared memory into local memory
- float4 mq[D4];
+ float slope = 1.0f;
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- mq[i] = (float4)sq4[i];
- }
+ // ALiBi
+ if (args.max_bias > 0.0f) {
+ const short h = iq2;
- // pointer to the mask
- device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+ slope = pow(base, exph);
+ }
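
The branch above reproduces the ALiBi slope schedule: the first n_head_log2 heads use base m0 with exponents 1, 2, 3, ..., and the remaining heads use base m1 with odd exponents 1, 3, 5, ... The same math as a C++ restatement (hypothetical helper; m0, m1, and n_head_log2 come in via the kernel arguments):

    #include <cmath>

    float alibi_slope(int h, int n_head_log2, float m0, float m1) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   e    = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return std::pow(base, (float) e);
    }
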
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
+ if (ic >= args.ne11) {
break;
}
+ if (has_mask) {
+ sm[tiisg] = pm[ic + tiisg];
+ }
+
// Q*K^T
{
-#pragma unroll
- for (short cc = 0; cc < C/4; ++cc) {
- float4 mqk = { 0.0h };
+ // each simdgroup processes 1 query and NE (NW/NL) head elements
+ for (short cc = 0; cc < C/NE; ++cc) {
+ qk_t mqk = 0.0f;
+
+ device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
- device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+ #pragma unroll(DK4/NL)
+ for (short ii = 0; ii < DK4; ii += NL) {
+ const short i = ii + tx;
-#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
+ k4_t mk;
+ deq_k_t4(pk + i/nl_k, i%nl_k, mk);
- float4x4 mk;
- mk[0] = (float4)pk4[i + 0*(nb11/8)];
- mk[1] = (float4)pk4[i + 1*(nb11/8)];
- mk[2] = (float4)pk4[i + 2*(nb11/8)];
- mk[3] = (float4)pk4[i + 3*(nb11/8)];
+ // note: this is less precise than the version below
+ //mqka[0] += dot(mq[0], mk[0]);
+ //mqka[1] += dot(mq[1], mk[1]);
+ //mqka[2] += dot(mq[2], mk[2]);
+ //mqka[3] += dot(mq[3], mk[3]);
- mqk += (float4) (mq[i] * mk);
+ //q4x4_t mq = sq4x4[i];
+ //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
+ //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
+ //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
+ //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+
+ mqk += dot((float4) mk, (float4) sq4[i]);
}
- // reduce the results from the threads in the simdgroup
- mqk += simd_shuffle_down(mqk, 16);
- mqk += simd_shuffle_down(mqk, 8);
- mqk += simd_shuffle_down(mqk, 4);
- mqk += simd_shuffle_down(mqk, 2);
- mqk += simd_shuffle_down(mqk, 1);
+ static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
+
+ // simdgroup reduce (NE = 4)
+ // [ 0 .. 7] -> [ 0]
+ // [ 8 .. 15] -> [ 8]
+ // [16 .. 23] -> [16]
+ // [24 .. 31] -> [24]
+ if (NE <= 1) {
+ mqk += simd_shuffle_down(mqk, 16);
+ }
+ if (NE <= 2) {
+ mqk += simd_shuffle_down(mqk, 8);
+ }
+ if (NE <= 4) {
+ mqk += simd_shuffle_down(mqk, 4);
+ }
+ if (NE <= 8) {
+ mqk += simd_shuffle_down(mqk, 2);
+ }
+ if (NE <= 16) {
+ mqk += simd_shuffle_down(mqk, 1);
+ }
// mqk = mqk*scale + mask*slope
- if (tiisg == 0) {
- mqk *= scale;
- if (softcap != 0.0f) {
- mqk = softcap*precise::tanh(mqk);
+ if (tx == 0) {
+ mqk *= args.scale;
+
+ if (args.logit_softcap != 0.0f) {
+ mqk = args.logit_softcap*precise::tanh(mqk);
}
- mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
- ss4[cc] = mqk;
- }
+ mqk += sm[NE*cc + ty]*slope;
+ ss[NE*cc + ty] = mqk;
+ }
}
}
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
// online softmax
{
- const short p = tiisg;
-
const float m = M;
- const float s = ss[p];
+ const float s = ss[tiisg];
M = simd_max(max(M, s));
@@ -3143,47 +3329,96 @@ kernel void kernel_flash_attn_ext_vec_f16(
S = S*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
- ss[p] = vs;
+ ss[tiisg] = vs;
// O = diag(ms)*O
-#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
- lo[i/NW] *= ms;
+ #pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
+ lo[ii/NL] *= ms;
}
}
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
// O = O + (Q*K^T)*V
{
-#pragma unroll
- for (short cc = 0; cc < C/4; ++cc) {
- device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
-
-#pragma unroll
- for (short ii = 0; ii < D4; ii += NW) {
- const short i = ii + tiisg;
-
- lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
- lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
- lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
- lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+ //#pragma unroll(C/NE)
+ for (short cc = 0; cc < C/NE; ++cc) {
+ device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+
+ const s4_t ms(ss[NE*cc + ty]);
+
+ #pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
+ const short i = ii + tx;
+
+ v4_t mv;
+ deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
+
+ lo[ii/NL] += o4_t(float4(mv)*float4(ms));
}
}
}
-
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
- ss[0] = S;
- ss[1] = M;
+ ss[0] = (s_t) S;
+ ss[1] = (s_t) M;
+ }
+ }
+
+ // simdgroup reduce (NE = 4)
+ // [ 0, 8, 16, 24] -> [ 0]
+ // [ 1, 9, 17, 25] -> [ 1]
+ // [ 2, 10, 18, 26] -> [ 2]
+ // [ 3, 11, 19, 27] -> [ 3]
+ // [ 4, 12, 20, 28] -> [ 4]
+ // [ 5, 13, 21, 29] -> [ 5]
+ // [ 6, 14, 22, 30] -> [ 6]
+ // [ 7, 15, 23, 31] -> [ 7]
+ for (short ii = 0; ii < DV4; ii += NL) {
+ if (NE > 1) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+ }
+
+ if (NE > 2) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
+ }
+
+ if (NE > 4) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
+ }
+
+ if (NE > 8) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
+ }
+
+ if (NE > 16) {
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
}
}
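
The lane diagram above describes a strided shuffle-down tree: after the per-lane accumulation, each output element lives in NE lanes spaced NL = NW/NE apart, and for NE = 4 the offsets 16 and 8 fold lanes {i, i+8, i+16, i+24} into lane i for i in 0..NL-1. A C++ emulation of that fold (assumed NW = 32; simd_shuffle_down is modeled by reading higher, not-yet-updated lanes):

    #include <array>
    #include <cstdio>

    int main() {
        constexpr int NW = 32, NE = 4, NL = NW/NE;
        std::array<float, NW> lane;
        for (int i = 0; i < NW; ++i) lane[i] = (float) i;  // stand-in partial results
        for (int off = NW/2; off >= NL; off >>= 1) {       // NE = 4: offsets 16, 8
            for (int i = 0; i + off < NW; ++i) {
                lane[i] += lane[i + off];                  // simd_shuffle_down(x, off)
            }
        }
        std::printf("%g %g\n", lane[0], lane[7]);          // 48 76: sums of {0,8,16,24} and {7,15,23,31}
    }
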
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
// store results to shared memory
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- sr4[i] = lo[ii/NW];
+ for (short i = tiisg; i < DV4; i += NL) {
+ sr4[i] = lo[i/NL];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3191,11 +3426,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
// parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
- const float S0 = ss[ 0];
- const float S1 = ss[r*SH + 0];
+ const float S0 = ss[ 0];
+ const float S1 = ss[r*(SH/2) + 0];
- const float M0 = ss[ 1];
- const float M1 = ss[r*SH + 1];
+ const float M0 = ss[ 1];
+ const float M1 = ss[r*(SH/2) + 1];
const float M = max(M0, M1);
@@ -3210,9 +3445,8 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+ for (short i = tiisg; i < DV4; i += NW) {
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
}
}
@@ -3225,15 +3459,50 @@ kernel void kernel_flash_attn_ext_vec_f16(
if (sgitg == 0) {
const float S = ss[0];
- for (short ii = 0; ii < D4; ii += NW) {
- short i = ii + tiisg;
- dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+ for (short i = tiisg; i < DV4; i += NW) {
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
}
}
}
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared memory;
+// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
+#define FA_TYPES \
+ half4, \
+ half4, \
+ half4, \
+ float, \
+ float, float4, \
+ half4
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 80, 80, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 80, 80, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 112, 112, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 112, 112, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 4>;
+
+#undef FA_TYPES
template<typename T0, typename T1>
kernel void kernel_cpy(
@@ -3809,7 +4078,8 @@ kernel void kernel_cpy_f32_iq4_nl(
}
}
-kernel void kernel_concat(
+template <typename src_t>
+static inline void concat_impl(
device const char * src0,
device const char * src1,
device char * dst,
@@ -3849,21 +4119,93 @@ kernel void kernel_concat(
int64_t o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
- device const float * x;
+ device const src_t * x;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
- x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ x = (device const src_t *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
- x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ x = (device const src_t *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
}
- device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device src_t * y = (device src_t *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*y = *x;
}
}
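
concat_impl reads from src0 while the coordinate along dim is inside src0's extent, and otherwise reads from src1 after subtracting the offset o[dim]. The same indexing in a one-dimensional C++ sketch (hypothetical helper):

    #include <vector>

    // concatenate two 1-D arrays: indices past a's extent map into b after
    // subtracting the offset, exactly like the o[dim] adjustment above
    std::vector<float> concat1d(const std::vector<float> & a, const std::vector<float> & b) {
        std::vector<float> out(a.size() + b.size());
        const size_t o = a.size(); // o[dim]
        for (size_t i = 0; i < out.size(); ++i) {
            out[i] = i < o ? a[i] : b[i - o];
        }
        return out;
    }
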
+kernel void kernel_concat_f32(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & dim,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ concat_impl<float>(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg);
+}
+
+kernel void kernel_concat_f16(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & dim,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ concat_impl<half>(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg);
+}
+
+
+
void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
@@ -7092,21 +7434,6 @@ kernel void kernel_mul_mv_iq6_k_f32(
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
- float4x4 temp = *(((device float4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
-
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
- half4x4 temp = *(((device half4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
template <typename type4x4>
void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) {
@@ -7221,22 +7548,6 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg
}
template <typename type4x4>
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
- device const int8_t * qs = ((device const int8_t *)xb->qs);
- if constexpr (is_same_v<type4x4, half4x4>) {
- const half d = xb->d;
- for (int i = 0; i < 16; i++) {
- reg[i/4][i%4] = (half)qs[i + 16*il] * d;
- }
- } else {
- const float d = xb->d;
- for (int i = 0; i < 16; i++) {
- reg[i/4][i%4] = qs[i + 16*il] * d;
- }
- }
-}
-
-template <typename type4x4>
void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
@@ -7967,7 +8278,7 @@ struct DequantizerRSBN {
};
// each block_q contains 16*nl weights
-template<typename T, typename simdgroup_T8x8, typename Dequantizer>
+template<typename T, typename simdgroup_T8x8, typename Dequantizer, typename src1_t>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
device float * dst,
@@ -8017,8 +8328,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
- device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0);
- device const float * y = (device const float *)(src1
+ device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0);
+ device const src1_t * y = (device const src1_t *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
@@ -8038,7 +8349,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
}
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+ if (is_same_v<src1_t, float>) {
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+ } else {
+ half2x4 h = *((device half2x4 *)y);
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = (float2x4)h;
+ }
deq.next();
y += BLOCK_SIZE_K;
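
The is_same_v branch above is resolved at compile time, so each instantiation keeps a single load path: float activations are copied into the threadgroup tile as-is, while half activations are widened to float, and the subsequent simdgroup matmul accumulates identically for both. A C++17 analogue of the dispatch (a sketch, not the Metal code):

    #include <type_traits>

    template <typename src1_t>
    float load_as_float(const src1_t * y) {
        if constexpr (std::is_same_v<src1_t, float>) {
            return *y;          // f32 path: no conversion
        } else {
            return (float) *y;  // f16 path: widen on load
        }
    }
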
@@ -8381,41 +8697,76 @@ template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get
template <typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
using DD = DefaultDequantizer<half4x4, block_q, nl, dequantize_func>;
-typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>) mat_mm_t;
-
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>>;
-template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>>;
-template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_0, 2, dequantize_q6_0>>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_K, QK_NL, dequantize_q6_K>>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xs, QK_NL, dequantize_iq2_xs>>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_s, QK_NL, dequantize_iq3_s>>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_s, QK_NL, dequantize_iq2_s>>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>>;
-template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>>;
-template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>>;
-template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>>;
-template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>>;
-template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>>;
-template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>>;
-template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>>;
-template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>>;
-template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>>;
-template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>>;
+typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, float>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, float>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>, float>;
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>, float>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>, float>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>, float>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>, float>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>, float>;
+template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_0, 2, dequantize_q6_0>, float>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>, float>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>, float>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>, float>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>, float>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>, float>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_K, QK_NL, dequantize_q6_K>, float>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>, float>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xs, QK_NL, dequantize_iq2_xs>, float>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>, float>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_s, QK_NL, dequantize_iq3_s>, float>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_s, QK_NL, dequantize_iq2_s>, float>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>, float>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>, float>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>, float>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>, float>;
+template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>, float>;
+template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>, float>;
+template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>, float>;
+template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>, float>;
+template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>, float>;
+template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>, float>;
+template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>, float>;
+template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>, float>;
+template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, float>;
+template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, float>;
+
+template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>, half>;
+template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>, half>;
+template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>, half>;
+template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>, half>;
+template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>, half>;
+template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>, half>;
+template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_1, 2, dequantize_q5_1>, half>;
+template [[host_name("kernel_mul_mm_q6_0_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_0, 2, dequantize_q6_0>, half>;
+template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q8_0, 2, dequantize_q8_0>, half>;
+template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q2_K, QK_NL, dequantize_q2_K>, half>;
+template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q3_K, QK_NL, dequantize_q3_K>, half>;
+template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_K, QK_NL, dequantize_q4_K>, half>;
+template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_K, QK_NL, dequantize_q5_K>, half>;
+template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q6_K, QK_NL, dequantize_q6_K>, half>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>, half>;
+template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_xs, QK_NL, dequantize_iq2_xs>, half>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>, half>;
+template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_s, QK_NL, dequantize_iq3_s>, half>;
+template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_s, QK_NL, dequantize_iq2_s>, half>;
+template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_s, QK_NL, dequantize_iq1_s>, half>;
+template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq1_m, QK_NL, dequantize_iq1_m>, half>;
+template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_nl, 2, dequantize_iq4_nl>, half>;
+template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_xs, QK_NL, dequantize_iq4_xs>, half>;
+template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq2_k, QK_NL, dequantize_iq2_k>, half>;
+template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq3_k, QK_NL, dequantize_iq3_k>, half>;
+template [[host_name("kernel_mul_mm_iq4_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq4_k, QK_NL, dequantize_iq4_k>, half>;
+template [[host_name("kernel_mul_mm_iq5_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq5_k, QK_NL, dequantize_iq5_k>, half>;
+template [[host_name("kernel_mul_mm_iq6_k_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_iq6_k, QK_NL, dequantize_iq6_k>, half>;
+template [[host_name("kernel_mul_mm_iq1_bn_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq1_bn, half, 4, dequantize_iq1_bn, true>, half>;
+template [[host_name("kernel_mul_mm_iq2_bn_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRSBN<half4x4, block_iq2_bn, float, 4, dequantize_iq2_bn>, half>;
+template [[host_name("kernel_mul_mm_iq4_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_ks, float, 16, dequantize_iq4_ks>, half>;
+template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq4_kss,float, 16, dequantize_iq4_kss>, half>;
+template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DequantizerRS<half4x4, block_iq2_ks, half, 16, dequantize_iq2_ks>, half>;
+
//
// indirect matrix-matrix multiplication