summary | refs | log | tree | commit | diff
path: root/ggml-metal.metal
diff options
context:
space:
mode:
author: Georgi Gerganov <ggerganov@gmail.com> 2023-11-13 16:55:52 +0200
committer: GitHub <noreply@github.com> 2023-11-13 16:55:52 +0200
commit: 3d68f364f15778dc326f5024f2e5af1ad6dfddef (patch)
tree: c0c11d150ba56b4f646261790728622efa30d8a1 /ggml-metal.metal
parent: c049b37d7baf558944501705b91ac89b26ee3e41 (diff)
ggml : sync (im2col, GPU conv, 32-bit arm compat) (#4060)
ggml-ci
Diffstat (limited to 'ggml-metal.metal')
-rw-r--r-- ggml-metal.metal 108
1 files changed, 107 insertions, 1 deletions
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7c35f23a..5d1357cd 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
}
}
+#define N_F16_F16 4
+
+// Dot products between one row of src0 and up to N_F16_F16 consecutive rows of
+// src1, both read as half, with one float result per row pair written to dst.
+//
+// Indexing, as used by the code below:
+//   tgpig.x - row of src0 (byte stride nb01); also the innermost dst index
+//   tgpig.y - group of N_F16_F16 rows of src1 (byte stride nb11)
+//   tgpig.z - batch index; src1 is offset by im*nb12, src0 by
+//             im/(ne12/ne02)*nb02
+//
+// Each lane starts at its tiisg and advances by 32, the per-lane partial sums
+// are combined with simd_sum, and lane 0 stores the result.
+// NOTE(review): the fixed stride of 32 presumably matches the SIMD-group
+// width - confirm against the dispatch code.
+//
+// Every product is formed in half precision and then added to a float
+// accumulator. For ne00 >= 128 the rows are read through half4 pointers and
+// lane 0 adds the ne00 % 4 leftover elements after the reduction.
+kernel void kernel_mul_mv_f16_f16(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t src0_row   = tgpig.x;
+    const int64_t src1_first = tgpig.y*N_F16_F16;
+    const int64_t batch      = tgpig.z;
+
+    // The same src0 row is viewed both element-wise and four elements at a time.
+    device const half  * a  = (device const half  *) (src0 + src0_row*nb01 + batch/(ne12/ne02)*nb02);
+    device const half4 * a4 = (device const half4 *) a;
+
+    // Rows shorter than 128 elements use the scalar path. The condition is the
+    // same for every lane, so the branch inside the loop is uniform.
+    const bool short_rows = ne00 < 128;
+
+    for (int n = 0; n < N_F16_F16; ++n) {
+        const int src1_row = src1_first + n;
+        if (src1_row >= ne11) {
+            break; // no more src1 rows in this group
+        }
+
+        device const half  * b  = (device const half  *) (src1 + src1_row*nb11 + batch*nb12);
+        device const half4 * b4 = (device const half4 *) b;
+
+        float partial = 0;
+        if (short_rows) {
+            for (int i = tiisg; i < ne00; i += 32) {
+                partial += (half) a[i] * (half) b[i];
+            }
+        } else {
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) {
+                    partial += (half) a4[i][k] * b4[i][k];
+                }
+            }
+        }
+
+        float total = simd_sum(partial);
+        if (tiisg == 0) {
+            if (!short_rows) {
+                // Elements past the last full group of four.
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    total += (half) a[i] * b[i];
+                }
+            }
+            dst[batch*ne1*ne0 + src1_row*ne0 + src0_row] = total;
+        }
+    }
+}
+
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+// Copies one float element of x into dst as half, or writes zero when the
+// computed source coordinate lies outside [0, IW) x [0, IH).
+//
+// The source coordinate is built from the threadgroup position scaled by
+// s0/s1, plus the thread position scaled by d0/d1, minus p0/p1.
+// NOTE(review): s*, d* and p* are presumably stride, dilation and padding -
+// confirm against the host-side dispatch.
+//
+// The destination index is built from the threadgroup and thread positions,
+// the grid size and the threadgroup size, with CHW as the multiplier of the
+// first term.
+kernel void kernel_im2col_f16(
+        device const float * x,
+        device        half * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    // Source coordinates: mixed uint/int32 arithmetic, stored as int32_t so
+    // that positions left of / above the input come out negative.
+    const int32_t src_col = tgpig.z * s0 + tpitg.z * d0 - p0;
+    const int32_t src_row = tgpig.y * s1 + tpitg.y * d1 - p1;
+
+    // Destination index, split into its two summands.
+    const uint dst_outer = tpitg.x * tgpg.y * tgpg.z + tgpig.y * tgpg.z + tgpig.z;
+    const uint dst_inner = tgpig.x * (ntg.y * ntg.z) + tpitg.y * ntg.z + tpitg.z;
+
+    const int32_t dst_idx = dst_outer * CHW + dst_inner;
+
+    const bool inside = src_row >= 0 && src_row < IH && src_col >= 0 && src_col < IW;
+
+    if (inside) {
+        const int32_t src_base = tpitg.x * ofs0 + tgpig.x * ofs1;
+        dst[dst_idx] = x[src_base + src_row * IW + src_col];
+    } else {
+        dst[dst_idx] = 0.0f;
+    }
+}
+
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,