diff options
Diffstat (limited to 'ggml/src/ggml-metal.m')
-rw-r--r-- | ggml/src/ggml-metal.m | 38 |
1 files changed, 34 insertions, 4 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index ff68ee19..e1e49fcb 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -106,6 +106,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS, @@ -152,6 +153,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32, @@ -192,6 +194,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32, @@ -229,6 +232,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32, @@ -266,6 +270,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16, @@ -303,6 +308,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32, @@ -756,6 +762,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN, get_rows_iq2_bn, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS, get_rows_iq3_ks, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS, get_rows_iq4_ks, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS, get_rows_iq4_kss, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS, get_rows_iq5_ks, true); @@ -802,6 +809,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_BN_F32, mul_mv_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32, mul_mv_iq3_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KS_F32, mul_mv_iq4_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_KSS_F32, mul_mv_iq4_kss_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ5_KS_F32, mul_mv_iq5_ks_f32, ctx->support_simdgroup_reduction); @@ -842,6 +850,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32, mul_mv_id_iq2_bn_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32, mul_mv_id_iq3_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32, mul_mv_id_iq4_ks_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32, mul_mv_id_iq4_kss_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_KS_F32, mul_mv_id_iq5_ks_f32, ctx->support_simdgroup_reduction); @@ -879,6 +888,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32, mul_mm_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32, mul_mm_iq3_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32, mul_mm_iq4_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32, mul_mm_iq4_kss_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32, mul_mm_iq5_ks_f32, ctx->support_simdgroup_mm); @@ -916,6 +926,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16, mul_mm_iq3_ks_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16, mul_mm_iq5_ks_f16, ctx->support_simdgroup_mm); @@ -953,6 +964,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32, mul_mm_id_iq2_bn_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32, mul_mm_id_iq3_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32, mul_mm_id_iq4_ks_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32, mul_mm_id_iq4_kss_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32, mul_mm_id_iq5_ks_f32, ctx->support_simdgroup_mm); @@ -2169,6 +2181,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F32 ].pipeline; break; @@ -2211,6 +2224,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_KS_F16 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break; case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break; case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_KS_F16 ].pipeline; break; @@ -2428,6 +2442,12 @@ static void ggml_metal_encode_node( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ3_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_KS_F32].pipeline; + } break; case GGML_TYPE_IQ4_KS: { nth0 = 4; @@ -2535,8 +2555,9 @@ static void ggml_metal_encode_node( src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { - const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) + : src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ? 32*sizeof(float) : 16*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -2648,6 +2669,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_KS_F32 ].pipeline; break; @@ -2849,6 +2871,12 @@ static void ggml_metal_encode_node( nth1 = 2; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_IQ3_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_KS_F32].pipeline; + } break; case GGML_TYPE_IQ4_KS: { nth0 = 4; @@ -2967,8 +2995,9 @@ static void ggml_metal_encode_node( src0t == GGML_TYPE_IQ2_KT|| src0t == GGML_TYPE_IQ3_KT) { //|| src0t == GGML_TYPE_IQ4_KT) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { - const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) + : src0t == GGML_TYPE_IQ3_K || src0t == GGML_TYPE_IQ3_KS ? 32*sizeof(float) : 16*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -3036,6 +3065,7 @@ static void ggml_metal_encode_node( case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ3_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_KS ].pipeline; break; case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; case GGML_TYPE_IQ5_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_KS ].pipeline; break; |