author    Kawrakow <iwankawrakow@gmail.com>    2025-03-02 13:47:38 +0200
committer GitHub <noreply@github.com>    2025-03-02 13:47:38 +0200
commit    a89adaa78f505675be7be6180f419b4b0158c15a (patch)
tree      ad82fa3ad44f66f37885bdf0d0d025166eff9535
parent    ef9a3d17b52bb5f6d55f7ef7e05e41e22f2ad81d (diff)
SER - Smart Expert Reduction (#239)
* A better way to measure the cost of ggml_barrier
* Smart expert selection
* Add ser option to llama-bench

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--  common/common.cpp                       35
-rw-r--r--  common/common.h                          2
-rw-r--r--  examples/llama-bench/llama-bench.cpp    65
-rw-r--r--  ggml/include/ggml.h                     13
-rw-r--r--  ggml/src/ggml-cuda.cu                   16
-rw-r--r--  ggml/src/ggml-cuda/argsort.cu           47
-rw-r--r--  ggml/src/ggml-cuda/argsort.cuh           2
-rw-r--r--  ggml/src/ggml-cuda/getrows.cu           17
-rw-r--r--  ggml/src/ggml.c                        143
-rw-r--r--  include/llama.h                          2
-rw-r--r--  src/llama.cpp                           15
11 files changed, 330 insertions, 27 deletions
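
SER (Smart Expert Reduction) keeps the MoE router's full ranking of experts, but any expert ranked beyond min_experts whose routing score falls below thresh times the best score is masked with index -1 and skipped in the expert matmuls. Below is a minimal sketch of that selection rule, written against plain C++ vectors rather than ggml tensors; the helper name ser_select is illustrative and not part of the commit.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative sketch of the SER rule: rank experts by score, always keep the
// first min_experts slots, and mask the remaining slots with -1 when their
// score drops below thresh * best_score.
static std::vector<int32_t> ser_select(const std::vector<float> & scores, int min_experts, float thresh) {
    std::vector<int32_t> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) { return scores[a] > scores[b]; });
    const float cutoff = thresh * scores[idx[0]];
    for (size_t j = std::max(min_experts, 0); j < idx.size(); ++j) {
        if (scores[idx[j]] < cutoff) idx[j] = -1;   // this expert is dropped for the current token
    }
    return idx;
}
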
diff --git a/common/common.cpp b/common/common.cpp
index 5c9070da..e62944b9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
}
return true;
}
+template<class T1, class T2>
+std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
+ std::vector<std::pair<T1,T2>> values;
+ std::istringstream str_stream(str);
+ std::string token;
+ T1 first_value;
+ int i = 0;
+ while (std::getline(str_stream, token, delim)) {
+ std::istringstream token_stream(token);
+ if (i%2 == 0) {
+ token_stream >> first_value;
+ } else {
+ T2 value;
+ token_stream >> value;
+ values.emplace_back(first_value, value);
+ }
+ i++;
+ }
+ return values;
+}
}
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
@@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.fused_moe_up_gate = true;
return true;
}
+ if (arg == "-ser" || arg == "--smart-expert-reduction") {
+ CHECK_ARG
+ auto values = string_split_pairs<int,float>(argv[i], ',');
+ if (values.size() == 1) {
+ params.min_experts = values.front().first;
+ params.thresh_experts = values.front().second;
+ } else {
+ invalid_param = true;
+ }
+ return true;
+ }
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+ options.push_back({ "*", "-ser, --smart-expert-reduction", "experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+ cparams.min_experts = params.min_experts;
+ cparams.thresh_experts = params.thresh_experts;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+ fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
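
On the command line the new option takes exactly one int,float pair: for example -ser 7,0.5 (or --smart-expert-reduction 7,0.5) sets min_experts = 7 and thresh_experts = 0.5, while the default -1,0 leaves expert selection unchanged. Anything other than a single pair is rejected as an invalid argument.
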
diff --git a/common/common.h b/common/common.h
index f35f3558..f6a55885 100644
--- a/common/common.h
+++ b/common/common.h
@@ -178,6 +178,8 @@ struct gpt_params {
int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
+ int min_experts = -1;
+ float thresh_experts = 0;
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index a08cb762..167525bc 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -215,6 +215,9 @@ static std::string pair_str(const std::pair<int, int> & p) {
return buf;
}
+// Ser = Smart Expert Reduction
+using Ser = std::pair<int,float>;
+
struct cmd_params {
std::vector<std::string> model;
std::vector<int> n_prompt;
@@ -234,6 +237,7 @@ struct cmd_params {
std::vector<bool> flash_attn;
std::vector<int> mla_attn;
std::vector<int> attn_max_batch;
+ std::vector<Ser> ser;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -267,6 +271,7 @@ static const cmd_params cmd_params_defaults = {
/* flash_attn */ {false},
/* mla_attn */ {0},
/* attn_max_batch */ {0},
+ /* ser */ {{-1,0.0f}},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -304,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
+ printf(" -ser, --smart-expert-reduction <i,f>(default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -387,6 +393,28 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
}
return true;
}
+template<class T1, class T2>
+std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
+ std::vector<std::pair<T1,T2>> values;
+ std::istringstream str_stream(str);
+ std::string token;
+ T1 first_value;
+ int i = 0;
+ while (std::getline(str_stream, token, delim)) {
+ std::istringstream token_stream(token);
+ if (i%2 == 0) {
+ token_stream >> first_value;
+ if (token_stream.fail()) return {};
+ } else {
+ T2 value;
+ token_stream >> value;
+ if (token_stream.fail()) return {};
+ values.emplace_back(first_value, value);
+ }
+ i++;
+ }
+ return values;
+}
}
static cmd_params parse_cmd_params(int argc, char ** argv) {
@@ -588,6 +616,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<int>(argv[i], split_delim);
params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
+ } else if (arg == "-ser" || arg == "--smart-expert-reduction") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split_pairs<int,float>(argv[i], split_delim);
+ params.ser.insert(params.ser.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -701,6 +736,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
+ if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -739,6 +775,7 @@ struct cmd_params_instance {
bool flash_attn;
int mla_attn;
int attn_max_batch;
+ Ser ser;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -787,6 +824,8 @@ struct cmd_params_instance {
cparams.mla_attn = mla_attn;
cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
+ cparams.min_experts = ser.first;
+ cparams.thresh_experts = ser.second;
cparams.embeddings = embeddings;
return cparams;
@@ -813,6 +852,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
for (const auto & amb : params.attn_max_batch)
+ for (const auto & ser : params.ser)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -836,6 +876,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
+ /* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -868,6 +909,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
+ /* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -900,6 +942,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
+ /* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -932,6 +975,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
+ /* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -975,6 +1019,7 @@ struct test {
bool flash_attn;
int mla_attn;
int attn_max_batch;
+ Ser ser;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -1007,6 +1052,7 @@ struct test {
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
attn_max_batch = inst.attn_max_batch;
+ ser = inst.ser;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -1101,7 +1147,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
- "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
+ "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1149,6 +1195,11 @@ struct test {
tensor_split_str += "/";
}
}
+ auto ser_to_string = [] (const Ser& ser) {
+ std::ostringstream str;
+ str << ser.first << ',' << ser.second;
+ return str.str();
+ };
std::vector<std::string> values = {
build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
@@ -1158,7 +1209,8 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
- std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
+ std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+ std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1328,6 +1380,9 @@ struct markdown_printer : public printer {
if (field == "attn_max_batch") {
return 5;
}
+ if (field == "ser") {
+ return 10;
+ }
if (field == "use_mmap") {
return 4;
}
@@ -1371,6 +1426,9 @@ struct markdown_printer : public printer {
if (field == "attn_max_batch") {
return "amb";
}
if (field == "ser") {
+ return "ser";
+ }
if (field == "use_mmap") {
return "mmap";
}
@@ -1432,6 +1490,9 @@ struct markdown_printer : public printer {
if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
fields.emplace_back("attn_max_batch");
}
+ if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) {
+ fields.emplace_back("ser");
+ }
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
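
llama-bench accepts the same pair per test configuration via -ser; when a non-default value is benchmarked, it is reported in an extra "ser" column, formatted as min,thresh by the ser_to_string helper above.
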
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index d12b90d0..91219d4a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -597,6 +597,7 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
+ GGML_OP_ARGSORT_THRESH,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
GGML_OP_SOFT_CAP_MAX,
@@ -1913,6 +1914,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);
+ GGML_API struct ggml_tensor * ggml_argsort_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int min_entries,
+ float threshold);
+
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
@@ -1924,6 +1931,12 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
+ GGML_API struct ggml_tensor * ggml_top_k_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k,
+ int min_entries,
+ float thresh);
#define GGML_KQ_MASK_PAD 32
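
A minimal, hypothetical sketch of driving the new ggml_top_k_thresh entry point on the CPU backend (not part of the commit; the scores and sizes are illustrative and error handling is omitted):

#include "ggml.h"
#include <cstring>
#include <cstdio>

int main() {
    // Rank 8 expert scores, view the top 4, always keep at least 2,
    // and mask experts scoring below 25% of the best one with -1.
    struct ggml_init_params ip = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * scores = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    const float vals[8] = { 0.40f, 0.30f, 0.15f, 0.06f, 0.04f, 0.02f, 0.02f, 0.01f };
    memcpy(scores->data, vals, sizeof(vals));

    struct ggml_tensor * topk = ggml_top_k_thresh(ctx, scores, 4, 2, 0.25f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, topk);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    const int32_t * ids = (const int32_t *) topk->data;
    for (int j = 0; j < 4; ++j) {
        printf("slot %d -> expert %d\n", j, ids[j]);   // -1 means the slot was masked
    }

    ggml_free(ctx);
    return 0;
}

With the scores above, the first three slots come back as 0, 1 and 2, and the fourth as -1, because 0.06 is below 0.25 * 0.40.
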
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index bc960678..85df0694 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2133,7 +2133,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2162,7 +2163,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
@@ -2301,7 +2303,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2362,7 +2365,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+ if (row_id_i < 0 || row_id_i >= n_as) continue;
+ //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
@@ -2637,6 +2641,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
+ case GGML_OP_ARGSORT_THRESH:
+ ggml_cuda_op_argsort_thresh(ctx, dst);
+ break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@@ -3252,6 +3259,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
+ case GGML_OP_ARGSORT_THRESH:
case GGML_OP_ACC:
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 607ded85..1734b771 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -8,7 +8,8 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
}
template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad,
+ int min_experts, float thresh_experts) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
@@ -51,9 +52,18 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
}
}
- // copy the result to dst without the padding
- if (col < ncols) {
- dst[row * ncols + col] = dst_row[col];
+ if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) {
+ __syncthreads();
+ float max_val = x_row[dst_row[0]];
+ if (col < ncols) {
+ dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1;
+ }
+ }
+ else {
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
}
}
@@ -65,7 +75,8 @@ static int next_power_of_2(int x) {
return n;
}
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows,
+ ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
@@ -77,9 +88,9 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
if (order == GGML_SORT_ORDER_ASC) {
- k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
} else if (order == GGML_SORT_ORDER_DESC) {
- k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts);
} else {
GGML_ABORT("fatal error");
}
@@ -100,5 +111,25 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream);
+}
+
+void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ int min_experts = dst->op_params[0];
+ float thresh;
+ memcpy(&thresh, dst->op_params + 1, sizeof(float));
+
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream);
}
diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh
index 68a00154..4bafa2d7 100644
--- a/ggml/src/ggml-cuda/argsort.cuh
+++ b/ggml/src/ggml-cuda/argsort.cuh
@@ -1,3 +1,5 @@
#include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index 4c370323..973b6526 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -4,7 +4,7 @@
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
@@ -31,7 +31,11 @@ static __global__ void k_get_rows(
// dequantize
dfloat2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
+ if (i01 >= 0 && i01 < ne01) {
+ dequantize_kernel(src0_row, ib, iqs, v);
+ } else {
+ v.x = v.y = 0;
+ }
dst_row[iybs + iqs + 0] = v.x;
dst_row[iybs + iqs + y_offset] = v.y;
@@ -40,7 +44,7 @@ static __global__ void k_get_rows(
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
@@ -56,11 +60,10 @@ static __global__ void k_get_rows_float(
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
- dst_row[i00] = src0_row[i00];
+ dst_row[i00] = i01 >= 0 && i01 < ne01 ? dst_t(src0_row[i00]) : dst_t(0);
}
template<int qk, int qr, dequantize_kernel_t dq>
@@ -88,7 +91,7 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
- ne00, /*ne01, ne02, ne03,*/
+ ne00, ne01, /*ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
@@ -120,7 +123,7 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
- ne00, /*ne01, ne02, ne03,*/
+ ne00, ne01, /*ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
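
Because argsort_thresh writes -1 into the slots of masked experts, every consumer of those ids has to tolerate out-of-range values: the mul_mat_id paths (CUDA above, CPU in ggml.c below) now skip such rows instead of asserting, and the get_rows kernels return zeros for them, so a masked expert contributes nothing to the weighted expert sum.
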
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7ba5e1ad..31fbc57e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3875,6 +3875,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
+ "ARGSORT_THRESH",
"LEAKY_RELU",
"SOFTCAP",
"SOFT_CAP_MAX",
@@ -3905,7 +3906,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3969,6 +3970,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
+ "argsort_thresh(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
"soft_max(k2*tanh(k1*x))",
@@ -3999,7 +4001,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -8497,6 +8499,27 @@ struct ggml_tensor * ggml_argsort(
return result;
}
+// ggml_argsort_thresh
+
+struct ggml_tensor * ggml_argsort_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int min_entries,
+ float thresh) {
+ bool is_node = false;
+
+ //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) min_entries);
+ ggml_set_op_params_f32(result, 1, thresh);
+
+ result->op = GGML_OP_ARGSORT_THRESH;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
// ggml_top_k
@@ -8516,6 +8539,32 @@ struct ggml_tensor * ggml_top_k(
return result;
}
+// ggml_top_k_thresh
+
+struct ggml_tensor * ggml_top_k_thresh(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int k,
+ int min_entries,
+ float thresh) {
+ GGML_ASSERT(a->ne[0] >= k);
+
+ //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh);
+ struct ggml_tensor * result;
+ if (min_entries <= 0 || thresh <= 0) {
+ result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+ } else {
+ result = ggml_argsort_thresh(ctx, a, min_entries, thresh);
+ }
+
+ result = ggml_view_4d(ctx, result,
+ k, result->ne[1], result->ne[2], result->ne[3],
+ result->nb[1], result->nb[2], result->nb[3],
+ 0);
+
+ return result;
+}
+
// ggml_flash_attn_ext
struct ggml_tensor * ggml_flash_attn_ext(
@@ -14485,7 +14534,8 @@ static void ggml_compute_forward_mul_mat_id(
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
- assert(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
@@ -14737,7 +14787,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
- assert(i02 >= 0 && i02 < n_as);
+ if (i02 < 0 || i02 >= n_as) continue;
+ //assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
@@ -15580,7 +15631,11 @@ static void ggml_compute_forward_get_rows_q(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- assert(i01 >= 0 && i01 < ne01);
+ if (i01 < 0 || i01 >= ne01) {
+ memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
+ continue;
+ }
+ //assert(i01 >= 0 && i01 < ne01);
dequantize_row_q(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -17667,6 +17722,75 @@ static void ggml_compute_forward_argsort(
}
}
+// ggml_compute_forward_argsort_thresh
+
+static void ggml_compute_forward_argsort_thresh_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ int min_entries = ggml_get_op_params_i32(dst, 0);
+ float thresh = ggml_get_op_params_f32(dst, 1);
+
+ //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ // C doesn't have a functional sort, so we do a bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if (src_data[dst_data[j]] < src_data[dst_data[k]]) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ float max_value = src_data[dst_data[0]];
+ //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]);
+ for (int j = min_entries; j < ne0; ++j) {
+ if (src_data[dst_data[j]] < max_value*thresh) {
+ //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]);
+ dst_data[j] = -1;
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_argsort_thresh(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_argsort_thresh_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16(
@@ -19476,6 +19600,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_argsort(params, tensor);
} break;
+ case GGML_OP_ARGSORT_THRESH:
+ {
+ ggml_compute_forward_argsort_thresh(params, tensor);
+ } break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@@ -20461,6 +20589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
+ case GGML_OP_ARGSORT_THRESH:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_OP_LEAKY_RELU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -21181,6 +21313,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
+ case GGML_OP_ARGSORT_THRESH:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
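
As a concrete worked example of the CPU path added above: with five expert scores {0.40, 0.30, 0.15, 0.10, 0.05} and -ser 2,0.5, the cutoff is 0.5 * 0.40 = 0.20; the two top-ranked experts are always kept, and the remaining three all score below 0.20, so their slots are set to -1 and only two experts are actually evaluated for that token.
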
diff --git a/include/llama.h b/include/llama.h
index bb43aebc..38a12744 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -386,6 +386,8 @@ extern "C" {
int mla_attn; // whether to use MLA attention [EXPERIMENTAL]
int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL]
bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL]
+ int min_experts;
+ float thresh_experts;
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
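
For programmatic use, a hedged sketch of enabling SER through the C API; the helper name and the chosen values are illustrative:

#include "llama.h"

// Hypothetical helper: build context params with SER enabled.
static struct llama_context_params make_ser_params(void) {
    struct llama_context_params cp = llama_context_default_params();
    cp.min_experts    = 4;     // never evaluate fewer than 4 experts per token
    cp.thresh_experts = 0.2f;  // drop experts scoring below 20% of the best one
    return cp;
}
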
diff --git a/src/llama.cpp b/src/llama.cpp
index 0dcc78dc..3a8b54ca 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2513,6 +2513,8 @@ struct llama_cparams {
int mla_attn;
int attn_max_batch;
bool fused_moe_up_gate;
+ int min_experts;
+ float thresh_experts;
enum llama_pooling_type pooling_type;
@@ -8631,7 +8633,8 @@ llm_expert_gating_func_type gating_op,
}
// select experts
- ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+ ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used,
+ lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);
@@ -8974,6 +8977,8 @@ struct llm_build_context {
const int mla_attn;
const int attn_max_batch;
const bool fused_moe_up_gate;
+ const int min_experts;
+ const float thresh_experts;
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
@@ -9027,6 +9032,8 @@ struct llm_build_context {
mla_attn (cparams.mla_attn),
attn_max_batch (cparams.attn_max_batch),
fused_moe_up_gate(cparams.fused_moe_up_gate),
+ min_experts (cparams.min_experts),
+ thresh_experts (cparams.thresh_experts),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
@@ -17725,6 +17732,8 @@ struct llama_context_params llama_context_default_params() {
/*.mla_attn =*/ 0,
/*.attn_max_batch =*/ 0,
/*.fused_moe_up_gate =*/ false,
+ /*.min_experts =*/ -1,
+ /*.thresh_experts =*/ 0.0f,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -17926,6 +17935,9 @@ struct llama_context * llama_new_context_with_model(
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate= params.fused_moe_up_gate;
+ cparams.min_experts = params.min_experts;
+ cparams.thresh_experts = params.thresh_experts;
+
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -17995,6 +18007,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch);
LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
+ LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);