diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 9 |
1 files changed, 9 insertions, 0 deletions
@@ -1393,6 +1393,9 @@ struct llama_cparams { bool mul_mat_q; bool offload_kqv; + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; }; struct llama_layer { @@ -6254,6 +6257,7 @@ static int llama_decode_internal( //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); ggml_backend_sched_reset(lctx.sched); + ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); ggml_cgraph * gf = llama_build_graph(lctx, batch); @@ -9276,6 +9280,8 @@ struct llama_context_params llama_context_default_params() { /*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_slow =*/ 1.0f, /*.yarn_orig_ctx =*/ 0, + /*.cb_eval =*/ nullptr, + /*.cb_eval_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, /*.mul_mat_q =*/ true, @@ -9416,6 +9422,9 @@ struct llama_context * llama_new_context_with_model( hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : hparams.n_ctx_train; + cparams.cb_eval = params.cb_eval; + cparams.cb_eval_user_data = params.cb_eval_user_data; + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; |