summary | refs | log | tree | commit | diff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- ggml-metal.m 334
1 files changed, 288 insertions, 46 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index f9bd69dc..1dcfa6ed 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -102,6 +102,21 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -140,6 +155,7 @@ struct ggml_metal_context {
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
@@ -177,6 +193,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
} else {
char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
vsnprintf(buffer2, len+1, format, args);
buffer2[len] = 0;
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -352,6 +370,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -392,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
@@ -452,6 +486,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -492,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
@@ -803,8 +853,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
case GGML_OP_MUL:
@@ -819,14 +870,38 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
case GGML_OP_ARGSORT:
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_GET_ROWS:
{
return op->ne[0] % 4 == 0;
}
@@ -1001,34 +1076,37 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL:
case GGML_OP_DIV:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
bool bcast_row = false;
int64_t nb = ne00;
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1063,7 +1141,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
@@ -1193,7 +1271,11 @@ void ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1444,7 +1526,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
@@ -1456,7 +1538,7 @@ void ggml_metal_graph_compute(
GGML_ASSERT(src0t == GGML_TYPE_I32);
- const int n_as = ne00;
+ const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1570,22 @@ void ggml_metal_graph_compute(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 0;
+ int ne11_mm_min = 1;
const int idx = ((int32_t *) dst->op_params)[0];
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne11 > ne11_mm_min) {
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1514,19 +1604,22 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1627,157 @@ void ggml_metal_graph_compute(
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default: // no mul_mv_id kernel exists for this src2t: log the offending type, then abort
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); // report the type this switch dispatches on (src2t); src0t is asserted to be GGML_TYPE_I32 earlier, so logging it hides the real culprit
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
@@ -1559,16 +1798,19 @@ void ggml_metal_graph_compute(
default: GGML_ASSERT(false && "not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
@@ -1813,7 +2055,7 @@ void ggml_metal_graph_compute(
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;