author     Kawrakow <iwankawrakow@gmail.com>           2025-02-23 14:31:11 +0200
committer  GitHub <noreply@github.com>                 2025-02-23 14:31:11 +0200
commit     ac1d259b93eccfa7371c6b00c5749400ff2b2aea (patch)
tree       fe8bb34c9dcbea805595c5087f00b188bb89fc05
parent     46bf73a37f1aabe6f0b40365b0c7b2ba831905f5 (diff)
Fused MoE ffn_up and ffn_gate (#229)
* Fusing MoE up * unary(gate)
* Fusing MoE up * unary(gate): CUDA
We get ~13% speedup for PP-512 and ~2% for TG-128
for DeepSeek-Lite
* On CUDA also fuse MoE down * (up * unary(gate))
in case the MUL_MAT_ID op for the down experts is the next
op in the graph.
* Command line option to enable fused MoE up*unary(gate)
* Add fmoe option to llama-bench
* Adding forgotten gelu, relu, silu on ARM
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
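What the new GGML_OP_MOE_FUSED_UP_GATE op computes, in scalar form: for each routed token/expert pair, the up and gate matmuls run on the same activations, the unary op (silu/gelu/relu) is applied to the gate result, and the two are multiplied elementwise; on CUDA the result can additionally be fed straight into the down matmul when that is the next op in the graph. The reference sketch below illustrates this for a single token and a single expert, assuming plain row-major float weights; the function and parameter names are illustrative only, and the real kernels operate on quantized tensors with tiling and multi-threading.

    #include <math.h>

    // y = W_down[e] * ( silu(W_gate[e] * x) * (W_up[e] * x) )  -- one token, one routed expert
    // w_up, w_gate: n_ff x n_embd; w_down: n_embd x n_ff; all row-major float (illustrative).
    static void moe_ffn_expert_ref(int n_embd, int n_ff,
                                   const float *w_up, const float *w_gate, const float *w_down,
                                   const float *x, float *h /* n_ff scratch */, float *y /* n_embd */) {
        for (int i = 0; i < n_ff; ++i) {
            float up = 0.0f, gate = 0.0f;
            for (int j = 0; j < n_embd; ++j) {
                up   += w_up  [i*n_embd + j] * x[j];
                gate += w_gate[i*n_embd + j] * x[j];
            }
            // fused step: unary(gate) * up computed in one pass over the up/gate rows
            h[i] = (gate / (1.0f + expf(-gate))) * up;   // silu(gate) * up
        }
        // optional "down" fusion (CUDA path): consume h directly without a round trip
        for (int j = 0; j < n_embd; ++j) {
            float acc = 0.0f;
            for (int i = 0; i < n_ff; ++i) acc += w_down[j*n_ff + i] * h[i];
            y[j] = acc;
        }
    }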
-rw-r--r--  common/common.cpp                       7
-rw-r--r--  common/common.h                         1
-rw-r--r--  examples/llama-bench/llama-bench.cpp   33
-rw-r--r--  ggml/include/ggml.h                    10
-rw-r--r--  ggml/src/ggml-cuda.cu                 263
-rw-r--r--  ggml/src/ggml-cuda/unary.cu            36
-rw-r--r--  ggml/src/ggml-cuda/unary.cuh            2
-rw-r--r--  ggml/src/ggml.c                       157
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.cpp          260
-rw-r--r--  ggml/src/iqk/iqk_mul_mat.h              5
-rw-r--r--  include/llama.h                         1
-rw-r--r--  src/llama.cpp                          44
12 files changed, 734 insertions, 85 deletions
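For orientation before the diff: the graph-building change in src/llama.cpp (last hunk below) boils down to choosing between the new single fused node and the previous two MUL_MAT_ID nodes followed by the fused multiply+activation. A minimal sketch of that choice, using the tensor names from the diff and assuming the caller passes a flag equivalent to cparams.fused_moe_up_gate (ggml_fused_mul_unary is this fork's pre-existing fused multiply+activation op):

    #include <stdbool.h>
    #include "ggml.h"

    // Build the routed up/gate part of a MoE FFN; use_fused mirrors cparams.fused_moe_up_gate.
    // Expert tensors are 3D (one matrix per expert), as required by ggml_mul_mat_id.
    static struct ggml_tensor * build_moe_up_gate(struct ggml_context * ctx,
                                                  struct ggml_tensor  * up_exps,
                                                  struct ggml_tensor  * gate_exps,
                                                  struct ggml_tensor  * cur,
                                                  struct ggml_tensor  * selected_experts,
                                                  bool use_fused) {
        if (use_fused) {
            // single graph node added by this commit: up/gate matmuls + unary + multiply
            return ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, GGML_UNARY_OP_SILU);
        }
        // previous path: two MUL_MAT_ID nodes, then fused multiply + activation
        struct ggml_tensor * up   = ggml_mul_mat_id(ctx, up_exps,   cur, selected_experts);
        struct ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts);
        return ggml_fused_mul_unary(ctx, gate, up, GGML_UNARY_OP_SILU);
    }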
diff --git a/common/common.cpp b/common/common.cpp index 6bf6e4f9..f975aee3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -817,6 +817,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = true; return true; } + if (arg == "-fmoe" || arg == "--fused-moe") { + params.fused_moe_up_gate = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1466,6 +1470,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2303,6 +2308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3301,6 +3307,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? 
"true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index b5b67986..f86a58cb 100644 --- a/common/common.h +++ b/common/common.h @@ -175,6 +175,7 @@ struct gpt_params { bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention bool mla_attn = false; // MLA + bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 0222c213..b0790e20 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -241,6 +241,7 @@ struct cmd_params { bool verbose; bool warmup; bool repack = false; + bool fmoe = false; output_formats output_format; output_formats output_format_stderr; }; @@ -271,6 +272,7 @@ static const cmd_params cmd_params_defaults = { /* verbose */ false, /* warmup */ true, /* repack */ false, + /* fmoe */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -307,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -607,6 +610,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.repack = std::stoi(argv[i]); + } else if (arg == "-fmoe" || arg == "--fused-moe") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fmoe = std::stoi(argv[i]); } else { invalid_param = true; break; @@ -675,6 +684,7 @@ struct cmd_params_instance { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -714,6 +724,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; + cparams.fused_moe_up_gate = fmoe; cparams.embeddings = embeddings; return cparams; @@ -765,6 +776,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -794,6 +806,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -823,6 +836,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -852,6 +866,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* 
.embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -892,6 +907,7 @@ struct test { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; int n_prompt; int n_gen; std::string test_time; @@ -922,6 +938,7 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; repack = inst.repack; + fmoe = inst.fmoe; n_prompt = inst.n_prompt; n_gen = inst.n_gen; test_kind = inst.test_kind; @@ -1012,7 +1029,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", - "tensor_split", "use_mmap", "embeddings", "repack", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1033,7 +1050,8 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "fused_moe") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1068,7 +1086,7 @@ struct test { std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1240,6 +1258,9 @@ struct markdown_printer : public printer { if (field == "repack") { return 3; } + if (field == "fused_moe") { + return 4; + } if (field == "test") { return 13; } @@ -1277,6 +1298,9 @@ struct markdown_printer : public printer { if (field == "repack") { return "rtr"; } + if (field == "fused_moe") { + return "fmoe"; + } if (field == "embeddings") { return "embd"; } @@ -1338,6 +1362,9 @@ struct markdown_printer : public printer { if (params.repack != cmd_params_defaults.repack) { fields.emplace_back("repack"); } + if (params.fmoe != cmd_params_defaults.fmoe) { + fields.emplace_back("fused_moe"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d2131a15..d12b90d0 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -567,6 +567,7 @@ extern "C" { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, GGML_OP_OUT_PROD, + GGML_OP_MOE_FUSED_UP_GATE, GGML_OP_SCALE, GGML_OP_SET, @@ -1320,6 +1321,15 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); + // MoE up + gate + unary + GGML_API struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e38e9568..26d06d56 100644 --- 
a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2195,7 +2195,252 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } -static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { +static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { + const ggml_tensor * src0_1 = dst->src[0]; + const ggml_tensor * src0_2 = dst->src[1]; + const ggml_tensor * src0 = src0_1; + const ggml_tensor * src1 = dst->src[2]; + const ggml_tensor * ids = dst->src[3]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + + cudaStream_t stream = ctx.stream(); + + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; + + std::vector<char> ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + ggml_tensor src0_1_row = *src0_1; + ggml_tensor src0_2_row = *src0_2; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + ggml_tensor final_dst; + ggml_tensor final_src; + + char * src0_1_original = (char *) src0_1->data; + char * src0_2_original = (char *) src0_2->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + src0_1_row.ne[2] = 1; + src0_1_row.ne[3] = 1; + src0_1_row.nb[3] = nb02; + src0_2_row.ne[2] = 1; + src0_2_row.ne[3] = 1; + src0_2_row.nb[3] = nb02; + + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + bool fuse_down = false; + if (next && next->op == GGML_OP_MUL_MAT_ID) { + //printf("Fusing MoE down gemm\n"); + fuse_down = true; + final_dst = *next; + final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1; + final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1]; + final_src = *next->src[0]; + //printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type), + // (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + // (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]); + final_src.ne[2] = final_src.ne[3] = 1; + final_src.nb[3] = final_src.nb[2]; + } + + if (ne12 == 1) { + ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + if (fuse_down) { + final_dst.src[1] = &dst_row; + } + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + //dst_row.data = dst_original + i1*nb1 + i2*nb2; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, 
&src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + } else { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); + CUDA_CHECK(cudaGetLastError()); + + } + } + } + } else { + ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc<char> final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } + + src1_row.data = src1_contiguous.get(); + + bool first = false; //true; + + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = 0; + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; + } + } + + if (num_src1_rows == 0) { + continue; + } + + ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ne11, ne10, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = nb11; + src1_row.nb[2] = num_src1_rows*nb11; + src1_row.nb[3] = num_src1_rows*nb11; + + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + if 
(fuse_down) { + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; + } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get(), + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); + + } + else { + + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get(), + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); + } + } + } + + return fuse_down; +} + +static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) { // why is this here instead of mul_mat? if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); @@ -2309,6 +2554,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); break; + case GGML_OP_MOE_FUSED_UP_GATE: + skip_next = ggml_cuda_up_gate_unary(ctx, dst, next); + break; case GGML_OP_SCALE: ggml_cuda_op_scale(ctx, dst); break; @@ -2595,7 +2843,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); @@ -2666,6 +2914,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!use_cuda_graph || cuda_graph_update_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + ggml_tensor * next = i < cgraph->n_nodes-1 ? 
cgraph->nodes[i+1] : nullptr; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; @@ -2680,11 +2929,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } #endif - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + bool skip_next = false; + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next); if (!ok) { GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + if (skip_next) ++i; } } @@ -2809,9 +3060,13 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: { struct ggml_tensor * a = op->src[0]; - struct ggml_tensor * b = op->src[1]; + struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1]; + if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) { + return false; + } if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 8ffddd6d..c422abbc 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -297,6 +297,19 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); } +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * src0_d, const float * src1_d, float * dst_d) { + + cudaStream_t stream = ctx.stream(); + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -304,19 +317,22 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, src1)); - cudaStream_t stream = ctx.stream(); ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data); - switch (op) { - case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - default: GGML_ASSERT(false); - } + //cudaStream_t stream = ctx.stream(); + + //const float * src0_d = (const float *)src0->data; + //const float * src1_d = (const float *)src1->data; + //float * dst_d = (float *)dst->data; + + //switch (op) { + // case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, 
src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // default: GGML_ASSERT(false); + //} } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 0235a319..e55c4262 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -36,5 +36,7 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * x, const float * y, float * z); void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eb39d574..8efe2653 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3845,6 +3845,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "MUL_MAT", "MUL_MAT_ID", "OUT_PROD", + "MOE_FUSED_UP_GATE", "SCALE", "SET", @@ -3904,7 +3905,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3938,6 +3939,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "X*Y", "X[i]*Y", "X*Y", + "X*Y1&X*Y2", "x*v", "y-\\>view(x)", @@ -3997,7 +3999,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6768,6 +6770,51 @@ struct ggml_tensor * ggml_mul_mat_id( return result; } +struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op) { + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + bool is_node = false; + + if (as_up->grad || as_gate->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + + // ggml_out_prod struct ggml_tensor * ggml_out_prod( @@ -14584,20 +14631,17 @@ IQK_MulMat_Not_Available:; #if GGML_USE_IQK_MULMAT static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_compute_params * params, - struct ggml_tensor * dst1, - struct ggml_tensor * dst2) { + struct ggml_tensor * dst) { - GGML_ASSERT(dst1->src[1] == dst2->src[1]); - GGML_ASSERT(dst1->src[2] == dst2->src[2]); - GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type); - GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == dst->src[1]->type); + GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1])); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const struct ggml_tensor * src1 = dst1->src[1]; - const struct ggml_tensor * ids = dst1->src[2]; - const struct ggml_tensor * src0_1 = dst1->src[0]; - const struct ggml_tensor * src0_2 = dst2->src[0]; - const struct ggml_tensor * src0 = src0_1; - const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works + const struct ggml_tensor * src1 = dst->src[2]; + const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * src0_1 = dst->src[0]; + const struct ggml_tensor * src0_2 = dst->src[1]; + const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works GGML_TENSOR_BINARY_OP_LOCALS @@ -14680,6 +14724,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate( ggml_barrier(params->shared); + + // so GGML_TENSOR_BINARY_OP_LOCALS works + // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -14696,28 +14743,34 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows - - if (nth%2 == 0) { - const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; - void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - type, src0_d, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst_d, nb1, nb2, - matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); - - } else { - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_1->type, (const char *)src0_1_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst1->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_2->type, (const char *)src0_2_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst2->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - } + // + if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], + type, src0_1_cur, src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); + +// if (nth%2 == 0) { +// const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; +// void * dst_d = ith%2 == 0 ? 
dst1->data : dst2->data; +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// type, src0_d, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst_d, nb1, nb2, +// matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); +// +// } else { +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_1->type, (const char *)src0_1_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst1->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_2->type, (const char *)src0_2_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst2->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// } } #undef MMID_MATRIX_ROW @@ -19152,6 +19205,7 @@ static void ggml_compute_forward_cross_entropy_loss_back( static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); + GGML_UNUSED(next); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { return false; @@ -19269,16 +19323,12 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT_ID: { -#if GGML_USE_IQK_MULMAT - if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] && - tensor->src[0]->type == next->src[0]->type) { - ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next); - skip_next = true; - break; - } -#endif ggml_compute_forward_mul_mat_id(params, tensor); } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + ggml_compute_forward_mul_mat_id_up_gate(params, tensor); + } break; case GGML_OP_OUT_PROD: { ggml_compute_forward_out_prod(params, tensor); @@ -20036,6 +20086,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_MOE_FUSED_UP_GATE: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_OUT_PROD: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21046,6 +21100,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -21249,6 +21304,20 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + cur = 0; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src2 = node->src[2]; + const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src2->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + } + const int n_as = src0->ne[2]; + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src2->ne[2] * sizeof(int64_t); // matrix_rows + } break; case GGML_OP_OUT_PROD: { if (ggml_is_quantized(node->src[0]->type)) { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ed7309cd..0f7cd1e5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -217,6 +217,118 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } + inline void gelu(int n, const float * src, float * dst); + inline void relu(int n, const float * 
src, float * dst); + inline void silu(int n, const float * src, float * dst); + inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); + else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); + else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else GGML_ABORT("fatal error"); + } + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { +#ifdef __aarch64__ + constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) +#else + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) +#endif + auto op = ggml_unary_op(unary_op); + float tmp[k_x_step*16]; + if (func16 && nrc_y >= 16) { + int n_step = (nrc_y - info.cur_y)/16; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + } + func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += 16; + } + } + info.cur_y += 16 * n_step; + if (info.cur_y == nrc_y) return; + } + int ny = funcs.size(); + while (!funcs[ny-1] && ny > 0) --ny; + int n_left = nrc_y - info.cur_y; + int n_step = n_left/ny; + if (n_step > 0) { + if (n_step*ny != n_left) { + ++n_step; + int ny1 = n_left/n_step; + int ny2 = ny1 + 1; + int my1 = n_step*ny2 - n_left; + int my2 = n_step - my1; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < my1; ++iy) { + funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny1; + } + for (int iy = 0; iy < my2; ++iy) { + funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny2; + } + } + info.cur_y += n_left; + } + else { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny; + } + } + info.cur_y += ny * n_step; + } + } + n_left = nrc_y - info.cur_y; + if (n_left > 0) { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + } + } + } static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); static inline int num_rows(ggml_type type) { #ifdef HAVE_FANCY_SIMD @@ -414,6 +526,34 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, return true; } +bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { + + const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; + assert(row_mapping != nullptr); + + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + size_t row_size_qx = strideA; + size_t row_size_qy = strideB; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + return true; +} + + namespace { inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { @@ -14660,6 +14800,45 @@ inline float32x4_t v_tanh(float16x8_t x) { auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); } +inline float32x4_t v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} +inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { + const float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); + arg = vmulq_f32(arg, vmulq_f32(x, c2)); + float32x4_t exp_arg = v_expf(arg); + float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one))); + 
uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f)); + return vbslq_f32(mask, x, gelu); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + int i = 0; + auto c1 = vdupq_n_f32(GELU_COEF_A); + auto c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI); + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, v_gelu(vld1q_f32(x + i), c1, c2)); + } + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; + for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i))); + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; +} #endif #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -14702,6 +14881,24 @@ inline __m512 v_tanh(__m512 x) { const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); return _mm512_mask_blend_ps(mask, res, one); } +inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) { + const __m512 one = _mm512_set1_ps(1.0f); + __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); + //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1)); + arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x)); + const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ); + const __m512 exp_arg = v_expf(arg); + const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)); + return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one)); +} +inline static __m512 v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} #endif #if defined(__AVX2__) && defined(__FMA__) @@ -14755,6 +14952,61 @@ inline __m256 v_tanh(__m256 x) { const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res)); } +inline static __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); + __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1)); + arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2)); + __m256 exp_arg = v_expf(arg); + __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one))); + return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu)); +} +inline static __m256 v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + //GGML_ASSERT(n%8 == 0); + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + { + __m512 c1 = _mm512_set1_ps(GELU_COEF_A); + __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), 
c1, c2)); + } +#endif +#if defined __AVX2__ && defined __FMA__ + if (i + 7 < n) { + __m256 c1 = _mm256_set1_ps(GELU_COEF_A); + __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); + + } +#endif + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_silu(_mm512_loadu_ps(x + i))); +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_silu(_mm256_loadu_ps(x + i))); +#endif + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; +} #endif } // namespace @@ -17107,6 +17359,14 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const return false; } +bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, + int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/, + int /*typeB*/, const void * /*B*/, long /*strideB*/, + float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) { + return false; +} + + bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k [[maybe_unused]] int int_type_v, // type of v [[maybe_unused]] int D, // head size diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index d5f340b2..767f89cf 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -28,6 +28,11 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); +bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); + bool iqk_flash_attn_noalibi(int type_k, // type of k int type_v, // type of v int Dk, // K head size diff --git a/include/llama.h b/include/llama.h index b5ad65e7..23e32642 100644 --- a/include/llama.h +++ b/include/llama.h @@ -377,6 +377,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] + bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 28e887ee..eed7aa61 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2516,6 +2516,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool mla_attn; + bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8628,30 +8629,20 @@ llm_expert_gating_func_type gating_op, } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); - - 
// This is equivalent to the commented out code below - ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); - - //switch (type_op) { - // case LLM_FFN_SILU: - // { - // gate = ggml_silu(ctx, gate); - // cb(gate, "ffn_moe_silu", il); - // } break; - // case LLM_FFN_GELU: - // { - // gate = ggml_gelu(ctx, gate); - // cb(gate, "ffn_moe_gelu", il); - // } break; - // default: - // GGML_ABORT("fatal error"); - //} - //ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + + ggml_tensor * par; + if (lctx.cparams.fused_moe_up_gate) { + par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } else { + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + // This is equivalent to the commented out code below + par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } cb(par, "ffn_moe_gate_par", il); ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] @@ -8907,6 +8898,7 @@ struct llm_build_context { const bool flash_attn; const bool mla_attn; + const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8958,6 +8950,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17605,6 +17598,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ false, + /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17804,6 +17798,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17871,6 +17866,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
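Usage note: on the command line the new path is enabled with -fmoe / --fused-moe (and -fmoe <0|1> in llama-bench); programmatically it is a single flag on llama_context_params. A minimal sketch, assuming a model is loaded elsewhere and llama_new_context_with_model is called as usual:

    #include <stdbool.h>
    #include "llama.h"

    int main(void) {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.fused_moe_up_gate = true;   // new field added in this commit (default: false)
        // ... load a model, then create the context with llama_new_context_with_model(model, cparams) ...
        return 0;
    }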