diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index 4f3f14e2..3e3be98c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -66,6 +66,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf_8); + GGML_METAL_DECL_KERNEL(get_rows_f32); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); @@ -145,7 +146,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->n_buffers = 0; ctx->concur_list_len = 0; - ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); #ifdef GGML_SWIFT // load the default.metallib file @@ -175,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; @@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf_8); + GGML_METAL_ADD_KERNEL(get_rows_f32); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); @@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max_4); + GGML_METAL_DEL_KERNEL(diag_mask_inf); GGML_METAL_DEL_KERNEL(diag_mask_inf_8); + GGML_METAL_DEL_KERNEL(get_rows_f32); GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); @@ -386,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; @@ -723,6 +728,7 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -730,6 +736,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_add_row]; } else { [encoder setComputePipelineState:ctx->pipeline_add]; @@ -746,6 +753,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -753,6 +761,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_mul_row]; } else { [encoder setComputePipelineState:ctx->pipeline_mul]; @@ -768,6 +777,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SCALE: { + GGML_ASSERT(ggml_is_contiguous(src0)); + const float scale = *(const float *) src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -867,8 +878,8 @@ void ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && @@ -893,9 +904,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -1045,6 +1059,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; @@ -1060,9 +1075,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; const int64_t n = ggml_nelements(src1); |