diff options
author | Douglas Hanley <thesecretaryofwar@gmail.com> | 2024-03-03 04:40:27 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-03 12:40:27 +0200 |
commit | 475df1d6cf817060028d3ff763cb8097d4ec40d6 (patch) | |
tree | 5cad43f149f24b7b3f40604b78b7971e458aa309 /llama.cpp | |
parent | 87c2e8b2797860a06af3d6c06b8488a8ff1a09ab (diff) |
llama : allow for user specified embedding pooling type (#5849)
* allow for user specified pooling type
* llama : use enum types over int
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 44 |
1 files changed, 28 insertions, 16 deletions
@@ -873,16 +873,16 @@ struct LLM_TN { // gguf helpers // -static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = { +static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, }; -static int32_t llama_rope_scaling_type_from_string(const std::string & name) { +static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { if (kv.second == name) { - return kv.first; + return (llama_rope_scaling_type) kv.first; } } @@ -1612,7 +1612,6 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; - int32_t rope_scaling_type_train; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -1620,8 +1619,9 @@ struct llama_hparams { bool causal_attn = true; bool need_kq_pos = false; - enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; - enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; + enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; @@ -1670,8 +1670,8 @@ struct llama_cparams { uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing - float rope_freq_base; - float rope_freq_scale; + float rope_freq_base; + float rope_freq_scale; uint32_t n_yarn_orig_ctx; // These hyperparameters are not exposed in GGUF, because all @@ -1683,7 +1683,7 @@ struct llama_cparams { float defrag_thold; bool offload_kqv; - bool do_pooling; + enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -2933,7 +2933,11 @@ template<> bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) { uint32_t tmp; const bool found = get_key(kid, tmp, required); - result = (enum llama_pooling_type) tmp; + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } return found; } @@ -3210,7 +3214,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: @@ -5175,7 +5179,7 @@ struct llm_build_context { n_kv (worst_case ? n_ctx : kv_self.n), kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), - pooling_type (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE), + pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { @@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); @@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); @@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() { /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, + /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, @@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embedding =*/ false, /*.offload_kqv =*/ true, - /*.do_pooling =*/ true, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.defrag_thold = params.defrag_thold; cparams.offload_kqv = params.offload_kqv; - cparams.do_pooling = params.do_pooling; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + } else { + cparams.pooling_type = hparams.pooling_type; + } + } + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } |