diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-03-02 13:47:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-03-02 13:47:38 +0200 |
commit | a89adaa78f505675be7be6180f419b4b0158c15a (patch) | |
tree | ad82fa3ad44f66f37885bdf0d0d025166eff9535 | |
parent | ef9a3d17b52bb5f6d55f7ef7e05e41e22f2ad81d (diff) |
SER - Smart Expert Reduction (#239)
* A better way to measure the cost of ggml_barrier
* Smart expert selection
* Add ser option to llama-bench
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r-- | common/common.cpp | 35 | ||||
-rw-r--r-- | common/common.h | 2 | ||||
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 65 | ||||
-rw-r--r-- | ggml/include/ggml.h | 13 | ||||
-rw-r--r-- | ggml/src/ggml-cuda.cu | 16 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/argsort.cu | 47 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/argsort.cuh | 2 | ||||
-rw-r--r-- | ggml/src/ggml-cuda/getrows.cu | 17 | ||||
-rw-r--r-- | ggml/src/ggml.c | 143 | ||||
-rw-r--r-- | include/llama.h | 2 | ||||
-rw-r--r-- | src/llama.cpp | 15 |
11 files changed, 330 insertions, 27 deletions
diff --git a/common/common.cpp b/common/common.cpp index 5c9070da..e62944b9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens } return true; } +template<class T1, class T2> +std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) { + std::vector<std::pair<T1,T2>> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + } else { + T2 value; + token_stream >> value; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} } #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } @@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.fused_moe_up_gate = true; return true; } + if (arg == "-ser" || arg == "--smart-expert-reduction") { + CHECK_ARG + auto values = string_split_pairs<int,float>(argv[i], ','); + if (values.size() == 1) { + params.min_experts = values.front().first; + params.thresh_experts = values.front().second; + } else { + invalid_param = true; + } + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); + options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); + fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index f35f3558..f6a55885 100644 --- a/common/common.h +++ b/common/common.h @@ -178,6 +178,8 @@ struct gpt_params { int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models + int min_experts = -1; + float thresh_experts = 0; bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a08cb762..167525bc 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -215,6 +215,9 @@ static std::string pair_str(const std::pair<int, int> & p) { return buf; } +// Ser = Smart Expert Reduction +using Ser = std::pair<int,float>; + struct cmd_params { std::vector<std::string> model; std::vector<int> n_prompt; @@ -234,6 +237,7 @@ struct cmd_params { std::vector<bool> flash_attn; std::vector<int> mla_attn; std::vector<int> attn_max_batch; + std::vector<Ser> ser; std::vector<std::vector<float>> tensor_split; std::vector<bool> use_mmap; std::vector<bool> embeddings; @@ -267,6 +271,7 @@ static const cmd_params cmd_params_defaults = { /* flash_attn */ {false}, /* mla_attn */ {0}, /* attn_max_batch */ {0}, + /* ser */ {{-1,0.0f}}, /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -304,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); + printf(" -ser, --smart-expert-reduction <i,f>(default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa <distribute|isolate|numactl> (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -387,6 +393,28 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens } return true; } +template<class T1, class T2> +std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) { + std::vector<std::pair<T1,T2>> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + if (token_stream.fail()) return {}; + } else { + T2 value; + token_stream >> value; + if (token_stream.fail()) return {}; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} } static cmd_params parse_cmd_params(int argc, char ** argv) { @@ -588,6 +616,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split<int>(argv[i], split_delim); params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); + } else if (arg == "-ser" || arg == "--smart-expert-reduction") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split_pairs<int,float>(argv[i], split_delim); + params.ser.insert(params.ser.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -701,6 +736,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } + if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -739,6 +775,7 @@ struct cmd_params_instance { bool flash_attn; int mla_attn; int attn_max_batch; + Ser ser; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -787,6 +824,8 @@ struct cmd_params_instance { cparams.mla_attn = mla_attn; cparams.attn_max_batch = attn_max_batch; cparams.fused_moe_up_gate = fmoe; + cparams.min_experts = ser.first; + cparams.thresh_experts = ser.second; cparams.embeddings = embeddings; return cparams; @@ -813,6 +852,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param for (const auto & fa : params.flash_attn) for (const auto & mla : params.mla_attn) for (const auto & amb : params.attn_max_batch) + for (const auto & ser : params.ser) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -836,6 +876,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -868,6 +909,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -900,6 +942,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -932,6 +975,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -975,6 +1019,7 @@ struct test { bool flash_attn; int mla_attn; int attn_max_batch; + Ser ser; std::vector<float> tensor_split; bool use_mmap; bool embeddings; @@ -1007,6 +1052,7 @@ struct test { flash_attn = inst.flash_attn; mla_attn = inst.mla_attn; attn_max_batch = inst.attn_max_batch; + ser = inst.ser; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -1101,7 +1147,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1149,6 +1195,11 @@ struct test { tensor_split_str += "/"; } } + auto ser_to_string = [] (const Ser& ser) { + std::ostringstream str; + str << ser.first << ',' << ser.second; + return str.str(); + }; std::vector<std::string> values = { build_commit, std::to_string(build_number), std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan), @@ -1158,7 +1209,8 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1328,6 +1380,9 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return 5; } + if (field == "ser") { + return 10; + } if (field == "use_mmap") { return 4; } @@ -1371,6 +1426,9 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return "amb"; } + if (field == "attn_max_batch") { + return "ser"; + } if (field == "use_mmap") { return "mmap"; } @@ -1432,6 +1490,9 @@ struct markdown_printer : public printer { if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { fields.emplace_back("attn_max_batch"); } + if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) { + fields.emplace_back("ser"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d12b90d0..91219d4a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -597,6 +597,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_ARGSORT_THRESH, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, GGML_OP_SOFT_CAP_MAX, @@ -1913,6 +1914,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + GGML_API struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float threshold); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, @@ -1924,6 +1931,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int k); + GGML_API struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh); #define GGML_KQ_MASK_PAD 32 diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index bc960678..85df0694 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2133,7 +2133,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); const int64_t i11 = id % ne11; const int64_t i12 = iid1; @@ -2162,7 +2163,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); if (row_id_i != i02) { continue; @@ -2301,7 +2303,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); const int64_t i11 = id % ne11; const int64_t i12 = iid1; @@ -2362,7 +2365,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + if (row_id_i < 0 || row_id_i >= n_as) continue; + //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); if (row_id_i != i02) { continue; @@ -2637,6 +2641,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_ARGSORT_THRESH: + ggml_cuda_op_argsort_thresh(ctx, dst); + break; case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; @@ -3252,6 +3259,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_ACC: case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 607ded85..1734b771 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -8,7 +8,8 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template<ggml_sort_order order> -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, + int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -51,9 +52,18 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n } } - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + __syncthreads(); + float max_val = x_row[dst_row[0]]; + if (col < ncols) { + dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + } + } + else { + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } } @@ -65,7 +75,8 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, + ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -77,9 +88,9 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else { GGML_ABORT("fatal error"); } @@ -100,5 +111,25 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream); +} + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + int min_experts = dst->op_params[0]; + float thresh; + memcpy(&thresh, dst->op_params + 1, sizeof(float)); + + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 68a00154..4bafa2d7 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c370323..973b6526 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -4,7 +4,7 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static __global__ void k_get_rows( const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -31,7 +31,11 @@ static __global__ void k_get_rows( // dequantize dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); + if (i01 >= 0 && i01 < ne01) { + dequantize_kernel(src0_row, ib, iqs, v); + } else { + v.x = v.y = 0; + } dst_row[iybs + iqs + 0] = v.x; dst_row[iybs + iqs + y_offset] = v.y; @@ -40,7 +44,7 @@ static __global__ void k_get_rows( template<typename src0_t, typename dst_t> static __global__ void k_get_rows_float( const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -56,11 +60,10 @@ static __global__ void k_get_rows_float( } const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = src0_row[i00]; + dst_row[i00] = i01 >= 0 && i01 < ne01 ? dst_t(src0_row[i00]) : dst_t(0); } template<int qk, int qr, dequantize_kernel_t dq> @@ -88,7 +91,7 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, @@ -120,7 +123,7 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr k_get_rows_float<<<block_nums, block_dims, 0, stream>>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7ba5e1ad..31fbc57e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3875,6 +3875,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "ARGSORT_THRESH", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -3905,7 +3906,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3969,6 +3970,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "argsort_thresh(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -3999,7 +4001,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -8497,6 +8499,27 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort + +struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float thresh) { + bool is_node = false; + + //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) min_entries); + ggml_set_op_params_f32(result, 1, thresh); + + result->op = GGML_OP_ARGSORT_THRESH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} // ggml_top_k @@ -8516,6 +8539,32 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k_thresh + +struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh) { + GGML_ASSERT(a->ne[0] >= k); + + //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh); + struct ggml_tensor * result; + if (min_entries <= 0 || thresh <= 0) { + result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + } else { + result = ggml_argsort_thresh(ctx, a, min_entries, thresh); + } + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -14485,7 +14534,8 @@ static void ggml_compute_forward_mul_mat_id( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -14737,7 +14787,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -15580,7 +15631,11 @@ static void ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 < 0 || i01 >= ne01) { + memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + continue; + } + //assert(i01 >= 0 && i01 < ne01); dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -17667,6 +17722,75 @@ static void ggml_compute_forward_argsort( } } +// ggml_compute_forward_argsort_thresh + +static void ggml_compute_forward_argsort_thresh_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + int min_entries = ggml_get_op_params_i32(dst, 0); + float thresh = ggml_get_op_params_f32(dst, 1); + + //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if (src_data[dst_data[j]] < src_data[dst_data[k]]) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + float max_value = src_data[dst_data[0]]; + //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]); + for (int j = min_entries; j < ne0; ++j) { + if (src_data[dst_data[j]] < max_value*thresh) { + //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]); + dst_data[j] = -1; + } + } + } +} + +static void ggml_compute_forward_argsort_thresh( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_thresh_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -19476,6 +19600,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_ARGSORT_THRESH: + { + ggml_compute_forward_argsort_thresh(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -20461,6 +20589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_ARGSORT_THRESH: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21181,6 +21313,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: diff --git a/include/llama.h b/include/llama.h index bb43aebc..38a12744 100644 --- a/include/llama.h +++ b/include/llama.h @@ -386,6 +386,8 @@ extern "C" { int mla_attn; // whether to use MLA attention [EXPERIMENTAL] int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] + int min_experts; + float thresh_experts; // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 0dcc78dc..3a8b54ca 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2513,6 +2513,8 @@ struct llama_cparams { int mla_attn; int attn_max_batch; bool fused_moe_up_gate; + int min_experts; + float thresh_experts; enum llama_pooling_type pooling_type; @@ -8631,7 +8633,8 @@ llm_expert_gating_func_type gating_op, } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -8974,6 +8977,8 @@ struct llm_build_context { const int mla_attn; const int attn_max_batch; const bool fused_moe_up_gate; + const int min_experts; + const float thresh_experts; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -9027,6 +9032,8 @@ struct llm_build_context { mla_attn (cparams.mla_attn), attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), + min_experts (cparams.min_experts), + thresh_experts (cparams.thresh_experts), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17725,6 +17732,8 @@ struct llama_context_params llama_context_default_params() { /*.mla_attn =*/ 0, /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, + /*.min_experts =*/ -1, + /*.thtesh_experts =*/ 0.0f, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17926,6 +17935,9 @@ struct llama_context * llama_new_context_with_model( cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17995,6 +18007,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |