summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Georgi Gerganov <ggerganov@gmail.com> 2024-01-31 15:35:41 +0200
committer: Georgi Gerganov <ggerganov@gmail.com> 2024-01-31 15:35:41 +0200
commit: efb7bdbbd061d087c788598b97992c653f992ddd (patch)
tree: 9dc409c591b273446f2ab4e7e166134f4614299a
parent: 15606309a05ccf7fadbaad5538cb7c32acb1e06b (diff)
metal : add im2col F32 dst support (#5132)
-rw-r--r-- ggml-metal.m 13
-rw-r--r-- ggml-metal.metal 33
2 files changed, 39 insertions, 7 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index f8785955..5260ed82 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ROPE_F16,
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
+ return true;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
- switch (src0->type) {
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+ switch (dst->type) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 2614d82e..efed6ad4 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
-kernel void kernel_im2col_f16(
+typedef void (im2col_t)(
device const float * x,
- device half * dst,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+ device const float * x,
+ device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+ device T * pdst = (device T *) (dst);
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst[offset_dst] = 0.0f;
+ pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,