summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-metal.m29
-rw-r--r--ggml-metal.metal22
2 files changed, 51 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index d6f2df94..6c630f7b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -30,10 +30,13 @@ struct ggml_metal_kernel {
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
+ GGML_METAL_KERNEL_TYPE_ADD_4,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
GGML_METAL_KERNEL_TYPE_MUL,
+ GGML_METAL_KERNEL_TYPE_MUL_4,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
GGML_METAL_KERNEL_TYPE_DIV,
+ GGML_METAL_KERNEL_TYPE_DIV_4,
GGML_METAL_KERNEL_TYPE_DIV_ROW,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
@@ -496,10 +499,13 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
// simd_sum and simd_max requires MTLGPUFamilyApple7
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_4, div_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
@@ -1100,6 +1106,29 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
break;
}
+ else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) &&
+ dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] &&
+ dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] &&
+ dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] &&
+ dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) {
+
+ switch (dst->op) {
+ case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break;
+ default: GGML_ASSERT(false);
+ }
+
+ int64_t n = ggml_nelements(dst)/4;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ break;
+ }
else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
diff --git a/ggml-metal.metal b/ggml-metal.metal
index ab63b096..c9439727 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -225,6 +225,13 @@ kernel void kernel_add_row(
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
+kernel void kernel_add_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig];
+}
kernel void kernel_mul_row(
device const float4 * src0,
@@ -235,6 +242,14 @@ kernel void kernel_mul_row(
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
+kernel void kernel_mul_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src1[tpig];
+}
+
kernel void kernel_div_row(
device const float4 * src0,
device const float4 * src1,
@@ -243,6 +258,13 @@ kernel void kernel_div_row(
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
+kernel void kernel_div_4(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] / src1[tpig];
+}
kernel void kernel_scale(
device const float * src0,