diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index f8785955..5260ed82 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -135,6 +135,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ROPE_F16, GGML_METAL_KERNEL_TYPE_ALIBI_F32, GGML_METAL_KERNEL_TYPE_IM2COL_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, @@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); @@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: + return true; + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + return false; case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARGSORT: @@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute( { GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; @@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute( const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; const int32_t N = src1->ne[is_2D ? 3 : 2]; @@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute( id<MTLComputePipelineState> pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; default: GGML_ASSERT(false); }; |