diff options
Diffstat (limited to 'ggml-metal.m')
-rw-r--r-- | ggml-metal.m | 106 |
1 files changed, 90 insertions, 16 deletions
diff --git a/ggml-metal.m b/ggml-metal.m index c2cda0bf..3d22b0b2 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -86,6 +86,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); @@ -114,6 +115,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); + GGML_METAL_DECL_KERNEL(im2col_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f16_f16); @@ -126,7 +128,7 @@ struct ggml_metal_context { // MSL code // TODO: move the contents here when ready // for now it is easier to work in a separate file -static NSString * const msl_library_source = @"see metal.metal"; +//static NSString * const msl_library_source = @"see metal.metal"; // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -142,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat ggml_metal_log_user_data = user_data; } -static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ +GGML_ATTRIBUTE_FORMAT(2, 3) +static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ if (ggml_metal_log_callback != NULL) { va_list args; va_start(args, format); @@ -210,7 +213,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { } else { GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * sourcePath; + NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + if (ggmlMetalPathResources) { + sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } if (sourcePath == nil) { GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); sourcePath = @"ggml-metal.metal"; @@ -281,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); @@ -311,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); + GGML_METAL_ADD_KERNEL(im2col_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f16_f16); @@ -329,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); break; } } @@ -380,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); @@ -410,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); + GGML_METAL_DEL_KERNEL(im2col_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f16_f16); @@ -467,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru const int64_t tsize = ggml_nbytes(t); + if (t->buffer && t->buffer->backend && t->buffer->backend->context) { + ctx = t->buffer->backend->context; + } + // find the view that contains the tensor fully for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; @@ -567,7 +584,7 @@ bool ggml_metal_add_buffer( ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); + GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); } else { GGML_METAL_LOG_INFO("\n"); } @@ -1024,7 +1041,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setThreadgroupMemoryLength:MAX(16, nth/32*sizeof(float)) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1133,6 +1150,7 @@ void ggml_metal_graph_compute( switch (src0t) { case GGML_TYPE_F32: { + GGML_ASSERT(src1t == GGML_TYPE_F32); [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; @@ -1140,13 +1158,18 @@ void ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; - nrows = ne11; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; + nrows = ne11; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + nrows = 4; + } } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16]; nrows = 4; } } break; @@ -1336,7 +1359,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -1355,7 +1378,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:MAX(16, nth*sizeof(float)) atIndex:0]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -1410,8 +1433,7 @@ void ggml_metal_graph_compute( const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); @@ -1459,6 +1481,58 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break; + default: GGML_ASSERT(false); + }; + + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: |