summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-05-12 07:47:46 +0300
committerGitHub <noreply@github.com>2025-05-12 07:47:46 +0300
commit8669c3db2b98f05775292778dd05f424ee0cd250 (patch)
treeed5c6a41e81ecd6b6620b748bfd765997663eb4c
parent504fb890d90ec27e5f4822b7bd772fa94d4d6aac (diff)
GPU offload policy (#405)
* Adding GPU offload policy * Minor --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--common/common.cpp17
-rw-r--r--common/common.h1
-rw-r--r--ggml/include/ggml-backend.h3
-rw-r--r--ggml/src/ggml-backend.c30
-rw-r--r--ggml/src/ggml-cuda.cu4
-rw-r--r--include/llama.h3
-rw-r--r--src/llama.cpp21
7 files changed, 77 insertions, 2 deletions
diff --git a/common/common.cpp b/common/common.cpp
index f0c618e0..ab936ee7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1213,6 +1213,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
+ if (arg == "--offload-policy" || arg == "-op") {
+ CHECK_ARG
+ auto p = string_split_pairs<int,int>(argv[i], ',');
+ if (p.empty()) {
+ fprintf(stderr, "error: Invalid offload policy argument: %s\n", argv[i]);
+ invalid_param = true;
+ } else {
+ params.offload_policy.insert(params.offload_policy.end(), p.begin(), p.end());
+ }
+ return true;
+ }
if (arg == "--host") {
CHECK_ARG
params.hostname = argv[i];
@@ -2222,6 +2233,10 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
return iparams;
}
+ for (auto [op, on_off] : params.offload_policy) {
+ llama_set_offload_policy(lctx, op, on_off);
+ }
+
if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
@@ -2418,6 +2433,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+ if (!params.offload_policy.empty()) cparams.offload_policy = (void *)&params.offload_policy;
+
return cparams;
}
diff --git a/common/common.h b/common/common.h
index b4f75236..fd83c9d3 100644
--- a/common/common.h
+++ b/common/common.h
@@ -143,6 +143,7 @@ struct gpt_params {
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
+ std::vector<std::pair<int,int>> offload_policy;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 5f3f1e28..2975d43a 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -208,6 +208,9 @@ extern "C" {
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+ // enable or disable op offload for a given op
+ GGML_API void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off);
+
//
// Utils
//
diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
index fd538f50..410ab9e5 100644
--- a/ggml/src/ggml-backend.c
+++ b/ggml/src/ggml-backend.c
@@ -1104,9 +1104,34 @@ struct ggml_backend_sched {
char * context_buffer;
size_t context_buffer_size;
+ uint32_t op_offload[(GGML_OP_COUNT + 31)/32];
+
bool debug;
};
+void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off) {
+ int int_op = (int)op;
+ if (!sched) return;
+ if (int_op < 0 || int_op >= (int)GGML_OP_COUNT) {
+ uint32_t mask = on_or_off ? 0xffffffff : 0;
+ for (int i = 0; i < (GGML_OP_COUNT + 31)/32; ++i) sched->op_offload[i] = mask;
+ return;
+ }
+ int i = int_op >> 5;
+ int j = int_op & 31;
+ if (on_or_off) {
+ sched->op_offload[i] |= (1u << j);
+ } else {
+ sched->op_offload[i] &= (~(1u << j));
+ }
+}
+
+static inline bool ggml_backend_sched_offload_enabled(ggml_backend_sched_t sched, enum ggml_op op) {
+ int int_op = (int)op;
+ if (!sched || op < 0 || op >= GGML_OP_COUNT) return false;
+ return sched->op_offload[int_op >> 5] & (1u << (int_op & 31));
+}
+
#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]
#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]
@@ -1181,6 +1206,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
}
// operations with weights are preferably run on the same backend as the weights
+ bool offload_enabled = ggml_backend_sched_offload_enabled(sched, tensor->op);
for (int i = 0; i < GGML_MAX_SRC; i++) {
const struct ggml_tensor * src = tensor->src[i];
if (src == NULL) {
@@ -1189,7 +1215,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
// check if a backend with higher prio wants to offload the op
- if (src_backend_id == sched->n_backends - 1) {
+ if (offload_enabled && src_backend_id == sched->n_backends - 1) {
for (int b = 0; b < src_backend_id; b++) {
if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
SET_CAUSE(tensor, "1.off");
@@ -1888,6 +1914,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
+ for (int i = 0; i < (GGML_OP_COUNT + 31)/32; ++i) sched->op_offload[i] = 0xffffffff;
+
sched->debug = getenv("GGML_SCHED_DEBUG") != NULL;
sched->n_backends = n_backends;
sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 87f80d0c..ef73ee7d 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3391,6 +3391,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1];
if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
+ printf("%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__);
return false;
}
//==================================================================
@@ -3399,6 +3400,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
//}
//==================================================================
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
+ printf("%s: returning false for op %d because (case 1)\n", __func__, (int)op->op);
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
@@ -3621,7 +3623,7 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const
const int min_batch_size = 32;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+ (op->ne[2] >= min_batch_size && (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MOE_FUSED_UP_GATE));
GGML_UNUSED(backend);
}
diff --git a/include/llama.h b/include/llama.h
index e2901861..f1511548 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -408,6 +408,7 @@ extern "C" {
// currently works only with CPU execution
ggml_abort_callback abort_callback;
void * abort_callback_data;
+ void * offload_policy;
};
// model quantization parameters
@@ -523,6 +524,8 @@ extern "C" {
struct llama_model * model,
struct llama_context_params params);
+ LLAMA_API void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_off);
+
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
diff --git a/src/llama.cpp b/src/llama.cpp
index d0f76c49..38a2b299 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19980,6 +19980,7 @@ struct llama_context_params llama_context_default_params() {
/*.thtesh_experts =*/ 0.0f,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
+ /*.offload_policy =*/ nullptr,
};
return result;
@@ -20574,6 +20575,19 @@ struct llama_context * llama_new_context_with_model(
}
}
+ if (params.offload_policy) {
+ const std::vector<std::pair<int, int>>& policy = *(const std::vector<std::pair<int, int>>*)params.offload_policy;
+ for (auto [op, on_off] : policy) {
+ if (op < 0 || op >= int(GGML_OP_COUNT)) {
+ LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting offload policy for all ops to %s\n", on_off ? "ON" : "OFF");
+ } else {
+ LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting offload policy for op %s to %s\n",
+ ggml_op_name(ggml_op(op)), on_off ? "ON" : "OFF");
+ }
+ ggml_backend_sched_set_op_offload(ctx->sched, ggml_op(op), on_off);
+ }
+ }
+
return ctx;
}
@@ -23222,3 +23236,10 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
fputs(text, stderr);
fflush(stderr);
}
+
+void llama_set_offload_policy(struct llama_context * lctx, int op, bool on_or_off) {
+ if (!lctx || !lctx->sched) return;
+ const char * op_name = op < 0 || op >= int(GGML_OP_COUNT) ? "all ops" : ggml_op_name(ggml_op(op));
+ printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXX offload(%s) = %d\n", op_name, on_or_off);
+ ggml_backend_sched_set_op_offload(lctx->sched, ggml_op(op), on_or_off);
+}