Diffstat (limited to 'ggml')
-rw-r--r--  ggml/include/ggml.h          |   6
-rw-r--r--  ggml/src/ggml-cuda.cu        |  12
-rw-r--r--  ggml/src/ggml-cuda/unary.cu  |  36
-rw-r--r--  ggml/src/ggml-cuda/unary.cuh |   3
-rw-r--r--  ggml/src/ggml-metal.m        |  35
-rw-r--r--  ggml/src/ggml-metal.metal    | 131
-rw-r--r--  ggml/src/ggml.c              |  91
7 files changed, 297 insertions, 17 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 83dab61b..8980285f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -494,6 +494,7 @@ extern "C" {
         GGML_OP_GROUP_NORM,
         GGML_OP_FUSED_RMS_NORM,
         GGML_OP_FUSED_MUL_UNARY,
+        GGML_OP_MULTI_ADD,
 
         GGML_OP_MUL_MAT,
         GGML_OP_MUL_MAT_ID,
@@ -930,6 +931,11 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);
 
+    GGML_API struct ggml_tensor * ggml_multi_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            int n_experts);
+
     // dst = a
     // view(dst, nb1, nb2, nb3, offset) += b
     // return dst
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 6759e202..e38e9568 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_MULTI_ADD:
+            ggml_cuda_op_multi_add(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
         }
+        if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
 
         if (node->op == GGML_OP_CPY) {
             // store the copy op parameter which changes with each token.
@@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_MULTI_ADD:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 7bc43d0f..8ffddd6d 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -52,6 +52,25 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
     dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    int64_t k = ne0*ne1;
+    if (i >= k) {
+        return;
+    }
+    int i1 = i / ne0;
+    int i0 = i % ne0;
+    float * result = (float *)(dst + i1*nb1);
+    const float * s = (const float *)(src0 + i1*nb01) + i0;
+    if (nused == 1) {
+        result[i0] = s[0];
+    } else {
+        float sum = s[0] + s[ne0];
+        for (int j = 2; j < nused; ++j) sum += s[j*ne0];
+        result[i0] = sum;
+    }
+}
+
 static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -218,6 +237,23 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
+    int64_t k = ne0 * ne1;
+    const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+    multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
+}
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    int nused = dst->op_params[0];
+    GGML_ASSERT(nused >= 1);
+    const char * src0 = (const char *)dst->src[0]->data;
+    cudaStream_t stream = ctx.stream();
+    multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index d2d478b4..0235a319 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -9,6 +9,7 @@
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_MULTI_ADD_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 8d350aa1..0498be1f 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -39,6 +39,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_4,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_MULTI_ADD,
+    GGML_METAL_KERNEL_TYPE_MULTI_ADD_4,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_4,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
@@ -577,6 +579,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD, multi_add, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD_4, multi_add_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
@@ -932,6 +936,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_MULTI_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -1349,6 +1354,36 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     }
                 } break;
+            case GGML_OP_MULTI_ADD:
+                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(dstt == GGML_TYPE_F32);
+                    GGML_ASSERT(ne02 == 1 && ne03 == 1);
+                    GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
+                    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+                    int n_expert = dst->op_params[0];
+                    GGML_ASSERT(n_expert >= 2);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+                    int64_t n = ne0*ne1;
+                    if (ne0%4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
+                    }
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+                    [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_REPEAT:
                 {
                     id<MTLComputePipelineState> pipeline;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e8f742fc..89cd412a 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -479,6 +479,44 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_multi_add_4(
+        device const float4 * src0,
+        device float4 * dst,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant int64_t & nb01,
+        constant int & n_expert,
+        uint tpig[[thread_position_in_grid]]) {
+
+    int64_t i0 = tpig % (ne0/4);
+    int64_t i1 = tpig / (ne0/4);
+    device float4 * dst_ptr = dst + i1*(nb1/16) + i0;
+    device const float4 * src_ptr = src0 + i1*(nb01/16) + i0;
+    float4 sum = src_ptr[0] + src_ptr[ne0/4];
+    for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0/4];
+    dst_ptr[0] = sum;
+}
+
+kernel void kernel_multi_add(
+        device const float * src0,
+        device float * dst,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & nb1,
+        constant int64_t & nb01,
+        constant int & n_expert,
+        uint tpig[[thread_position_in_grid]]) {
+
+    int64_t i0 = tpig % ne0;
+    int64_t i1 = tpig / ne0;
+    device float * dst_ptr = dst + i1*nb1/4 + i0;
+    device const float * src_ptr = src0 + i1*nb01/4 + i0;
+    float sum = src_ptr[0] + src_ptr[ne0];
+    for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0];
+    dst_ptr[0] = sum;
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device float * dst,
@@ -8197,6 +8235,7 @@ kernel void kernel_mul_mm_id(
         threadgroup uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiitg[[thread_index_in_threadgroup]],
+        uint3 ntg3[[threads_per_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int32_t i02 = tgpig.z;
@@ -8204,25 +8243,87 @@ kernel void kernel_mul_mm_id(
 
     device const uchar * src0 = src0s + i02*nb02;
 
-    // row indices
-    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+    uint ntg = ntg3.x * ntg3.y * ntg3.z;
+    uint n = nei0*nei1;
 
-    // TODO: parallelize this loop
-    int64_t _ne1 = 0;
-    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
-        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
-            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
-            if (id == i02) {
-                //if (tiitg == 0) {
-                rowids[_ne1] = ushort2(ii0, ii1);
-                //}
-                _ne1++;
-            }
-        }
-    }
+    //uint npt = (n + ntg - 1) / ntg;
+    //uint first = tiitg * npt;
+    //uint last = first + npt <= n ? first + npt : n;
+    //uint nhave = 0;
+    //for (uint i = first; i < last; ++i) {
+    //    uint ii0 = i % nei0;
+    //    uint ii1 = i / nei0;
+    //    int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+    //    if (id == i02) ++nhave;
+    //}
+    //threadgroup uint * nums = (threadgroup uint *)shared_memory;
+    //nums[tiitg] = nhave;
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    //uint nprev = 0;
+    //for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
+    //int64_t _ne1 = nprev;
+    //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
+
+    //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+    //for (uint i = first; i < last; ++i) {
+    //    uint ii0 = i % nei0;
+    //    uint ii1 = i / nei0;
+    //    int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+    //    if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
+    //}
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    //
+    // The following is slightly faster than the commented out version above
+    //
+    uint nhave = 0;
+    for (uint i = tiitg; i < n; i += ntg) {
+        uint ii0 = i % nei0;
+        uint ii1 = i / nei0;
+        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+        if (id == i02) ++nhave;
+    }
+    threadgroup uint * nums = (threadgroup uint *)shared_memory;
+    nums[tiitg] = nhave;
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
+    uint nprev = 0;
+    for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
+    int64_t _ne1 = nprev;
+    for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
+
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+    for (uint i = tiitg; i < n; i += ntg) {
+        uint ii0 = i % nei0;
+        uint ii1 = i / nei0;
+        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+        if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // This is the original version that is ridiculously slow.
+    //// row indices
+    //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+
+    //// TODO: parallelize this loop
+    //int64_t _ne1 = 0;
+    //for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+    //    for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+    //        int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+    //        if (id == i02) {
+    //            //if (tiitg == 0) {
+    //            rowids[_ne1] = ushort2(ii0, ii1);
+    //            //}
+    //            _ne1++;
+    //        }
+    //    }
+    //}
+
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+
     kernel_mul_mm_id_impl<Dequantizer>(
         src0,
         src1,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5570b1fc..39218ff4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GROUP_NORM",
     "FUSED_RMS_NORM",
     "FUSED_MUL_UNARY",
+    "MULTI_ADD",
 
     "MUL_MAT",
     "MUL_MAT_ID",
@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "group_norm(x)",
     "fused_rms_norm(x)",
     "fused_mul_unary(x)",
+    "x1+x2+x3+...",
 
     "X*Y",
     "X[i]*Y",
@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5106,6 +5108,29 @@ struct ggml_tensor * ggml_add_inplace(
     return ggml_add_impl(ctx, a, b, true);
 }
 
+// ggml_add
+
+struct ggml_tensor * ggml_multi_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        int n_experts) {
+
+    bool is_node = false;
+
+    if (n_experts < 1) {
+        GGML_ABORT("fatal error");
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MULTI_ADD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->op_params[0] = n_experts;
+
+    return result;
+}
+
 // ggml_add_cast
 
 static struct ggml_tensor * ggml_add_cast_impl(
@@ -10425,6 +10450,59 @@ static void ggml_compute_forward_add(
     }
 }
 
+static void ggml_compute_forward_multi_add_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    struct ggml_tensor * src = dst->src[0];
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src->nb[0] == sizeof(float));
+    GGML_ASSERT(ggml_are_same_shape(src, dst));
+    GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+    const int n_add = dst->op_params[0];
+    GGML_ASSERT(n_add > 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    int64_t ne0 = dst->ne[0];
+
+    for (int i1 = ir0; i1 < ir1; ++i1) {
+
+        float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1] );
+        const float * data = (const float *) ((const char *)src->data + i1*src->nb[1]);
+        memset(dst_ptr, 0, ne0*sizeof(float));
+        for (int j = 0; j < n_add; ++j) {
+            ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+        }
+    }
+}
+
+static void ggml_compute_forward_multi_add(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            ggml_compute_forward_multi_add_f32(params, dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -18202,6 +18280,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add1(params, tensor);
             } break;
+        case GGML_OP_MULTI_ADD:
+            {
+                ggml_compute_forward_multi_add(params, tensor);
+            } break;
         case GGML_OP_ACC:
             {
                 ggml_compute_forward_acc(params, tensor);
@@ -18947,6 +19029,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
+        case GGML_OP_MULTI_ADD:
+            {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
         case GGML_OP_CONCAT:
             {
                 GGML_ABORT("fatal error"); // TODO: implement
@@ -19996,6 +20082,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
+        case GGML_OP_MULTI_ADD:
            {
                n_tasks = n_threads;
            } break;
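
Usage note (not part of the patch above): the CPU, CUDA, and Metal kernels all read n_experts
consecutive blocks of ne0 floats per destination row, stepping by the source tensor's row
stride nb[1], so the operand of ggml_multi_add is expected to be a 2D F32 view whose row
stride spans every expert's output for one token. The sketch below shows how the op might be
wired into a MoE graph; the helper name, the experts_out layout ([n_embd, n_expert_used,
n_tokens], contiguous), and the variable names are illustrative assumptions, not taken from
this commit.

    #include "ggml.h"

    // Hypothetical helper: reduce a per-expert output tensor over the expert
    // dimension with the new GGML_OP_MULTI_ADD instead of a chain of ggml_add.
    static struct ggml_tensor * moe_sum_experts(
            struct ggml_context * ctx,
            struct ggml_tensor * experts_out,   // F32, [n_embd, n_expert_used, n_tokens]
            int n_expert_used) {
        const int64_t n_embd   = experts_out->ne[0];
        const int64_t n_tokens = experts_out->ne[2];
        // View each token's n_expert_used rows as one logical row of n_embd floats;
        // the view's nb[1] covers all expert outputs for that token, which is the
        // stride the multi_add kernels step over.
        struct ggml_tensor * view = ggml_view_2d(ctx, experts_out,
                n_embd, n_tokens,
                (size_t) n_expert_used*n_embd*sizeof(float), 0);
        // The result has the view's shape, [n_embd, n_tokens]: one summed row per token.
        return ggml_multi_add(ctx, view, n_expert_used);
    }

The same reduction could be expressed as n_expert_used - 1 separate GGML_OP_ADD nodes; a
single MULTI_ADD node keeps the graph smaller and avoids materializing intermediate partial
sums.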