summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-01-03 11:35:46 +0200
committerGeorgi Gerganov <ggerganov@gmail.com>2024-01-03 14:38:38 +0200
commit289313716ff7ccf6aee284f686a0fe8cbc7714af (patch)
tree64bfab202ca257db34b29cb08d946f8f27a0f5c2
parentab62fc3e5520f5a143c36cb23c269f11aa4dafd6 (diff)
metal : add kernel_get_rows_i32
ggml-ci
-rw-r--r--ggml-metal.m4
-rw-r--r--ggml-metal.metal29
2 files changed, 33 insertions, 0 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 7a369b55..7aa92c14 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -87,6 +87,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DECL_KERNEL(get_rows_i32);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
@@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
+ GGML_METAL_ADD_KERNEL(get_rows_i32);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
@@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+ GGML_METAL_DEL_KERNEL(get_rows_i32);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
@@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
+ case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
default: GGML_ASSERT(false && "not implemented");
}
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 9aa7b502..a7d3f9ef 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
}
}
+kernel void kernel_get_rows_i32(
+ device const void * src0,
+ device const char * src1,
+ device int32_t * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32