diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2024-09-17 10:54:42 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-17 10:54:42 +0300 |
commit | 4ee889f15891b6e8785c88c932ec27f129cae285 (patch) | |
tree | 77afce9b1c5e2807bce1553a15514faa6843f747 | |
parent | 2874b984006c6c8d0691ce000dcd9ca2cf9ff6fd (diff) |
BF16 support on Metal (#56)
* BF16 support on Metal
* Faster BF16 Metal dot product
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | ggml/src/ggml-metal.m | 37 | ||||
-rw-r--r-- | ggml/src/ggml-metal.metal | 125 |
2 files changed, 161 insertions, 1 deletions
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 749396c9..02794e3c 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -113,6 +113,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, @@ -146,6 +148,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, @@ -176,6 +179,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, @@ -206,6 +210,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, @@ -628,6 +633,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16, mul_mv_bf16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); @@ -661,6 +668,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); @@ -691,6 +699,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32, mul_mv_id_iq6_k_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); @@ -721,6 +730,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); @@ -856,9 +866,12 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs } static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) { + for (size_t i = 0, n = 3; i < n; ++i) { if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { - return false; + if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID && op->op != GGML_OP_GET_ROWS) { + return false; + } } } @@ -1861,6 +1874,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; @@ -1945,6 +1959,19 @@ static enum ggml_status ggml_metal_graph_compute( nrows = 4; } } break; + case GGML_TYPE_BF16: + { + if (src1t == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; + } + else if (src1t == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F16].pipeline; + } + else { + GGML_ABORT("not implemented"); + } + nrows = 4; + } break; case GGML_TYPE_Q4_0: { nth0 = 8; @@ -2229,6 +2256,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; @@ -2307,6 +2335,13 @@ static enum ggml_status ggml_metal_graph_compute( nth1 = 1; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; case GGML_TYPE_Q4_0: { nth0 = 8; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index e050fdc3..6553f465 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1667,11 +1667,121 @@ kernel void kernel_mul_mv( tiisg); } +template<typename T1> +void kernel_mul_mv_bf16_impl( + device const char * src0, + device const char * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + int64_t ne10, + int64_t ne11, + int64_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_MV_T_T; + const int64_t im = tgpig.z; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const uint16_t * x = (device const uint16_t *) (src0 + offset0); + + typedef union { uint32_t u[4]; float f[4]; } aux_t; + aux_t aux; + + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + aux.u[0] = x[4*i+0] << 16; + aux.u[1] = x[4*i+1] << 16; + aux.u[2] = x[4*i+2] << 16; + aux.u[3] = x[4*i+3] << 16; + sumf += aux.f[0] * (float)y[4*i+0] + aux.f[1] * (float)y[4*i+1] + aux.f[2] * (float)y[4*i+2] + aux.f[3] * (float)y[4*i+3]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +template<typename T1> +kernel void kernel_mul_mv_bf16( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_bf16_impl<T1>( + src0, + src1, + dst, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t; template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>; template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>; template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>; +template [[host_name("kernel_mul_mv_bf16_f16")]] kernel mul_mv_t kernel_mul_mv_bf16<half>; +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv_bf16<float>; template<typename T, typename T4> kernel void kernel_mul_mv_1row( @@ -6612,6 +6722,17 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) } template <typename type4x4> +void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) { + device const uint16_t * src_u16 = (device const uint16_t *)src; + typedef union { uint32_t u; float f; } aux_t; + aux_t aux; + for (int i = 0; i < 16; ++i) { + aux.u = (uint32_t)src_u16[i] << 16; + reg[i/4][i%4] = aux.f; + } +} + +template <typename type4x4> void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); const float d1 = il ? (xb->d / 16.h) : xb->d; @@ -7692,6 +7813,7 @@ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_q_t kernel_get_rows_q<half4x4, 1, dequantize_bf16>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>; @@ -7732,6 +7854,7 @@ typedef decltype(kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequanti template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<float4x4, 1, dequantize_f32>>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_f16>>; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<half4x4, 1, dequantize_bf16>>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_0, 2, dequantize_q4_0>>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q4_1, 2, dequantize_q4_1>>; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, simdgroup_half8x8, DD<block_q5_0, 2, dequantize_q5_0>>; @@ -7769,6 +7892,7 @@ typedef decltype(kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>) mat_mm_id_t; template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<float4x4, 1, dequantize_f32>>; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_f16>>; +template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<half4x4, 1, dequantize_bf16>>; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_0, 2, dequantize_q4_0>>; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q4_1, 2, dequantize_q4_1>>; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<DD<block_q5_0, 2, dequantize_q5_0>>; @@ -7987,6 +8111,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>; +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_bf16_impl<float>>>; template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; |