summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-10-31 12:05:27 +0100
committerGitHub <noreply@github.com>2024-10-31 12:05:27 +0100
commit52874c5d21819bd63cc4c500f2fb1be435d16b5e (patch)
treebff705dd887d124958d1ad3c224de9ee9732de1a
parent5ad6439486e5bfdd8e34213a36beb56b74842bbe (diff)
Faster MoE inference (#112)
* multi_sdd: WIP * multi_sdd: CPU works * multi_add: CUDA * multi_add: simplify * multi_add: Metal * Metal: speed up mul_mat_id For the Granite-1B MoE model PP-512 goes from 156 t/s to 890 t/s, so nearly a 6X speedup! --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--ggml/include/ggml.h6
-rw-r--r--ggml/src/ggml-cuda.cu12
-rw-r--r--ggml/src/ggml-cuda/unary.cu36
-rw-r--r--ggml/src/ggml-cuda/unary.cuh3
-rw-r--r--ggml/src/ggml-metal.m35
-rw-r--r--ggml/src/ggml-metal.metal131
-rw-r--r--ggml/src/ggml.c91
-rw-r--r--src/llama.cpp53
8 files changed, 333 insertions, 34 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 83dab61b..8980285f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -494,6 +494,7 @@ extern "C" {
GGML_OP_GROUP_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,
+ GGML_OP_MULTI_ADD,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -930,6 +931,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_multi_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_experts);
+
// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 6759e202..e38e9568 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2220,6 +2220,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD:
ggml_cuda_op_add(ctx, dst);
break;
+ case GGML_OP_MULTI_ADD:
+ ggml_cuda_op_multi_add(ctx, dst);
+ break;
case GGML_OP_ACC:
ggml_cuda_op_acc(ctx, dst);
break;
@@ -2607,6 +2610,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
+ if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
+ // disable CUDA graphs for batch size > 1 for now.
+ // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ use_cuda_graph = false;
+#ifndef NDEBUG
+ GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+ }
if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
@@ -2927,6 +2938,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
+ case GGML_OP_MULTI_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 7bc43d0f..8ffddd6d 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -52,6 +52,25 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
dst[i] = x[i] * y[i] / (1.0f + expf(-x[i]));
}
+static __global__ void multi_add_f32(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst) {
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+ int64_t k = ne0*ne1;
+ if (i >= k) {
+ return;
+ }
+ int i1 = i / ne0;
+ int i0 = i % ne0;
+ float * result = (float *)(dst + i1*nb1);
+ const float * s = (const float *)(src0 + i1*nb01) + i0;
+ if (nused == 1) {
+ result[i0] = s[0];
+ } else {
+ float sum = s[0] + s[ne0];
+ for (int j = 2; j < nused; ++j) sum += s[j*ne0];
+ result[i0] = sum;
+ }
+}
+
static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -218,6 +237,23 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void multi_add_f32_cuda(int nused, int64_t ne0, int64_t ne1, int64_t nb1, int64_t nb01, const char * src0, char * dst, cudaStream_t stream) {
+ int64_t k = ne0 * ne1;
+ const int num_blocks = (k + CUDA_MULTI_ADD_BLOCK_SIZE - 1) / CUDA_MULTI_ADD_BLOCK_SIZE;
+ multi_add_f32<<<num_blocks, CUDA_MULTI_ADD_BLOCK_SIZE, 0, stream>>>(nused, ne0, ne1, nb1, nb01, src0, dst);
+}
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+ int nused = dst->op_params[0];
+ GGML_ASSERT(nused >= 1);
+ const char * src0 = (const char *)dst->src[0]->data;
+ cudaStream_t stream = ctx.stream();
+ multi_add_f32_cuda(nused, dst->ne[0], dst->ne[1], dst->nb[1], dst->src[0]->nb[1], src0, (char *)dst->data, stream);
+}
+
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index d2d478b4..0235a319 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -9,6 +9,7 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_MULTI_ADD_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -35,3 +36,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 8d350aa1..0498be1f 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -39,6 +39,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_4,
GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_MULTI_ADD,
+ GGML_METAL_KERNEL_TYPE_MULTI_ADD_4,
GGML_METAL_KERNEL_TYPE_MUL,
GGML_METAL_KERNEL_TYPE_MUL_4,
GGML_METAL_KERNEL_TYPE_MUL_ROW,
@@ -577,6 +579,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_4, add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD, multi_add, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MULTI_ADD_4, multi_add_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_4, mul_4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
@@ -932,6 +936,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
+ case GGML_OP_MULTI_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
@@ -1349,6 +1354,36 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
+ case GGML_OP_MULTI_ADD:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+ GGML_ASSERT(ne02 == 1 && ne03 == 1);
+ GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float));
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ int n_expert = dst->op_params[0];
+ GGML_ASSERT(n_expert >= 2);
+
+ id<MTLComputePipelineState> pipeline = nil;
+ int64_t n = ne0*ne1;
+ if (ne0%4 == 0) {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline;
+ n /= 4;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline;
+ }
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index e8f742fc..89cd412a 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -479,6 +479,44 @@ kernel void kernel_sqr(
dst[tpig] = src0[tpig] * src0[tpig];
}
+kernel void kernel_multi_add_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant int64_t & nb01,
+ constant int & n_expert,
+ uint tpig[[thread_position_in_grid]]) {
+
+ int64_t i0 = tpig % (ne0/4);
+ int64_t i1 = tpig / (ne0/4);
+ device float4 * dst_ptr = dst + i1*(nb1/16) + i0;
+ device const float4 * src_ptr = src0 + i1*(nb01/16) + i0;
+ float4 sum = src_ptr[0] + src_ptr[ne0/4];
+ for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0/4];
+ dst_ptr[0] = sum;
+}
+
+kernel void kernel_multi_add(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant int64_t & nb01,
+ constant int & n_expert,
+ uint tpig[[thread_position_in_grid]]) {
+
+ int64_t i0 = tpig % ne0;
+ int64_t i1 = tpig / ne0;
+ device float * dst_ptr = dst + i1*nb1/4 + i0;
+ device const float * src_ptr = src0 + i1*nb01/4 + i0;
+ float sum = src_ptr[0] + src_ptr[ne0];
+ for (int i = 2; i < n_expert; ++i) sum += src_ptr[i*ne0];
+ dst_ptr[0] = sum;
+}
+
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,
@@ -8197,6 +8235,7 @@ kernel void kernel_mul_mm_id(
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
+ uint3 ntg3[[threads_per_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int32_t i02 = tgpig.z;
@@ -8204,25 +8243,87 @@ kernel void kernel_mul_mm_id(
device const uchar * src0 = src0s + i02*nb02;
- // row indices
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ uint ntg = ntg3.x * ntg3.y * ntg3.z;
+ uint n = nei0*nei1;
- // TODO: parallelize this loop
- int64_t _ne1 = 0;
- for (ushort ii1 = 0; ii1 < nei1; ii1++) {
- for (ushort ii0 = 0; ii0 < nei0; ii0++) {
- int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
- if (id == i02) {
- //if (tiitg == 0) {
- rowids[_ne1] = ushort2(ii0, ii1);
- //}
- _ne1++;
- }
- }
- }
+ //uint npt = (n + ntg - 1) / ntg;
+ //uint first = tiitg * npt;
+ //uint last = first + npt <= n ? first + npt : n;
+ //uint nhave = 0;
+ //for (uint i = first; i < last; ++i) {
+ // uint ii0 = i % nei0;
+ // uint ii1 = i / nei0;
+ // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ // if (id == i02) ++nhave;
+ //}
+ //threadgroup uint * nums = (threadgroup uint *)shared_memory;
+ //nums[tiitg] = nhave;
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ //uint nprev = 0;
+ //for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
+ //int64_t _ne1 = nprev;
+ //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
+
+ //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ //for (uint i = first; i < last; ++i) {
+ // uint ii0 = i % nei0;
+ // uint ii1 = i / nei0;
+ // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ // if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
+ //}
+
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ //
+ // The following is slightly faster than the commented out version above
+ //
+ uint nhave = 0;
+ for (uint i = tiitg; i < n; i += ntg) {
+ uint ii0 = i % nei0;
+ uint ii1 = i / nei0;
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ if (id == i02) ++nhave;
+ }
+ threadgroup uint * nums = (threadgroup uint *)shared_memory;
+ nums[tiitg] = nhave;
threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint nprev = 0;
+ for (uint i = 0; i < tiitg; ++i) nprev += nums[i];
+ int64_t _ne1 = nprev;
+ for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i];
+
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ for (uint i = tiitg; i < n; i += ntg) {
+ uint ii0 = i % nei0;
+ uint ii1 = i / nei0;
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ if (id == i02) rowids[nprev++] = ushort2(ii0, ii1);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // This is the original version that is ridiculously slow.
+ //// row indices
+ //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+
+ //// TODO: parallelize this loop
+ //int64_t _ne1 = 0;
+ //for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+ // for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+ // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ // if (id == i02) {
+ // //if (tiitg == 0) {
+ // rowids[_ne1] = ushort2(ii0, ii1);
+ // //}
+ // _ne1++;
+ // }
+ // }
+ //}
+
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+
kernel_mul_mm_id_impl<Dequantizer>(
src0,
src1,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5570b1fc..39218ff4 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GROUP_NORM",
"FUSED_RMS_NORM",
"FUSED_MUL_UNARY",
+ "MULTI_ADD",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"group_norm(x)",
"fused_rms_norm(x)",
"fused_mul_unary(x)",
+ "x1+x2+x3+...",
"X*Y",
"X[i]*Y",
@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
+static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5106,6 +5108,29 @@ struct ggml_tensor * ggml_add_inplace(
return ggml_add_impl(ctx, a, b, true);
}
+// ggml_add
+
+struct ggml_tensor * ggml_multi_add(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_experts) {
+
+ bool is_node = false;
+
+ if (n_experts < 1) {
+ GGML_ABORT("fatal error");
+ }
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MULTI_ADD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->op_params[0] = n_experts;
+
+ return result;
+}
+
// ggml_add_cast
static struct ggml_tensor * ggml_add_cast_impl(
@@ -10425,6 +10450,59 @@ static void ggml_compute_forward_add(
}
}
+static void ggml_compute_forward_multi_add_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ struct ggml_tensor * src = dst->src[0];
+
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src->nb[0] == sizeof(float));
+ GGML_ASSERT(ggml_are_same_shape(src, dst));
+ GGML_ASSERT(dst->ne[2] == 1 && dst->ne[3] == 1);
+
+ const int n_add = dst->op_params[0];
+ GGML_ASSERT(n_add > 0);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(dst);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ int64_t ne0 = dst->ne[0];
+
+ for (int i1 = ir0; i1 < ir1; ++i1) {
+
+ float * dst_ptr = (float *) ((char *) dst->data + i1*dst->nb[1] );
+ const float * data = (const float *) ((const char *)src->data + i1*src->nb[1]);
+ memset(dst_ptr, 0, ne0*sizeof(float));
+ for (int j = 0; j < n_add; ++j) {
+ ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data + j*ne0);
+ }
+ }
+}
+
+static void ggml_compute_forward_multi_add(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: {
+ ggml_compute_forward_multi_add_f32(params, dst);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
@@ -18202,6 +18280,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add1(params, tensor);
} break;
+ case GGML_OP_MULTI_ADD:
+ {
+ ggml_compute_forward_multi_add(params, tensor);
+ } break;
case GGML_OP_ACC:
{
ggml_compute_forward_acc(params, tensor);
@@ -18947,6 +19029,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: implement
}
+ case GGML_OP_MULTI_ADD:
+ {
+ GGML_ABORT("fatal error"); // TODO: implement
+ }
case GGML_OP_CONCAT:
{
GGML_ABORT("fatal error"); // TODO: implement
@@ -19996,6 +20082,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_ACC:
+ case GGML_OP_MULTI_ADD:
{
n_tasks = n_threads;
} break;
diff --git a/src/llama.cpp b/src/llama.cpp
index a55254c0..2b9a1b1a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8351,25 +8351,40 @@ static struct ggml_tensor * llm_build_moe_ffn(
experts = ggml_mul(ctx, experts, weights);
- // aggregate experts
- ggml_tensor * moe_out = nullptr;
- for (int i = 0; i < n_expert_used; ++i) {
- ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
- experts->nb[2], i*experts->nb[1]);
-
- if (i == 0) {
- moe_out = cur_expert;
- } else {
- moe_out = ggml_add(ctx, moe_out, cur_expert);
- }
- }
-
if (n_expert_used == 1) {
- // avoid returning a non-contiguous tensor
- moe_out = ggml_cont(ctx, moe_out);
- }
+ return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
+ }
+ if (n_expert_used == 2) {
+ return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
+ ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
+ }
+ return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
+
+ //// aggregate experts
+ //ggml_tensor * moe_out = nullptr;
+ ////ggml_tensor * first_expert = nullptr;
+ //for (int i = 0; i < n_expert_used; ++i) {
+ // ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
+ // experts->nb[2], i*experts->nb[1]);
+
+ // if (i == 0) {
+ // moe_out = cur_expert;
+ // //first_expert = cur_expert;
+ // //printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
+ // // (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
+ // // (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
+ // } else {
+ // moe_out = ggml_add(ctx, moe_out, cur_expert);
+ // //printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
+ // }
+ //}
+
+ //if (n_expert_used == 1) {
+ // // avoid returning a non-contiguous tensor
+ // moe_out = ggml_cont(ctx, moe_out);
+ //}
- return moe_out;
+ //return moe_out;
}
static struct ggml_tensor * llm_build_kqv(
@@ -9011,6 +9026,7 @@ struct llm_build_context {
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
if (hparams.f_attention_scale != 0) {
+ // Why is hparams.f_attention_scale not simply absorbed into model.layers[il].wq ?
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
}
cb(Qcur, "Qcur", il);
@@ -9062,6 +9078,7 @@ struct llm_build_context {
// For Granite architecture
if (hparams.f_residual_scale) {
+ // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].wv ?
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
@@ -9103,6 +9120,7 @@ struct llm_build_context {
// For Granite architecture
if (hparams.f_residual_scale) {
+ // Why is hparams.f_residual_scale not simply absorbed into model.layers[il].ffn_down_exps ?
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
@@ -9128,6 +9146,7 @@ struct llm_build_context {
// For Granite architecture
if (hparams.f_logit_scale) {
+ // Why is hparams.f_logit_scale not simply absorbed into model.output ?
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
}