diff options
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 545 |
1 files changed, 255 insertions, 290 deletions
@@ -887,10 +887,10 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { std::vector<char> result(8, 0); - const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -931,9 +931,9 @@ static const size_t MB = kB*kB; static const size_t GB = kB*kB*kB; struct llama_hparams { + bool vocab_only; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on - uint32_t n_ctx; // context size used during inference uint32_t n_embd; uint32_t n_head; uint32_t n_head_kv; @@ -944,8 +944,8 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; - float rope_freq_base; - float rope_freq_scale; + float rope_freq_base_train; + float rope_freq_scale_train; bool operator!=(const llama_hparams & other) const { return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT @@ -962,15 +962,18 @@ struct llama_hparams { uint32_t n_embd_gqa() const { return n_embd/n_gqa(); } +}; - size_t kv_size() const { - size_t result = 2ull; - result *= (size_t) n_embd_gqa(); - result *= (size_t) n_ctx; - result *= (size_t) n_layer; - result *= sizeof(ggml_fp16_t); - return result; - } +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + bool mul_mat_q; }; struct llama_layer { @@ -1148,11 +1151,8 @@ struct llama_model { }; struct llama_context { - llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} + llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { - if (model_owner) { - delete &model; - } #ifdef GGML_USE_METAL if (ctx_metal) { ggml_metal_free(ctx_metal); @@ -1163,27 +1163,26 @@ struct llama_context { } } + llama_cparams cparams; + + const llama_model & model; + + // key + value cache for the self attention + struct llama_kv_cache kv_self; + std::mt19937 rng; bool has_evaluated_once = false; + int64_t t_start_us; + int64_t t_load_us; int64_t t_sample_us = 0; - int64_t t_eval_us = 0; int64_t t_p_eval_us = 0; + int64_t t_eval_us = 0; int32_t n_sample = 0; // number of tokens sampled - int32_t n_eval = 0; // number of eval calls int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - - const llama_model & model; - - bool model_owner = false; - - int64_t t_load_us; - int64_t t_start_us; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; + int32_t n_eval = 0; // number of eval calls // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector<float> logits; @@ -1218,10 +1217,10 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, + uint32_t n_ctx, int n_gpu_layers) { const uint32_t n_embd = hparams.n_embd_gqa(); const uint32_t n_layer = hparams.n_layer; - const uint32_t n_ctx = hparams.n_ctx; const int64_t n_mem = n_layer*n_ctx; const int64_t n_elements = n_embd*n_mem; @@ -1255,11 +1254,20 @@ static bool llama_kv_cache_init( (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS + size_t vram_kv_cache = 0; + if (n_gpu_layers > (int)n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); + LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.v); } if (n_gpu_layers > (int)n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); + LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.k); + } + if (vram_kv_cache > 0) { + LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); } #endif // GGML_USE_CUBLAS @@ -1715,7 +1723,7 @@ struct llama_model_loader { lmlock->grow_to(size_lock); } break; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS case GGML_BACKEND_GPU: case GGML_BACKEND_GPU_SPLIT: // old code: @@ -1748,7 +1756,15 @@ struct llama_model_loader { // load LLaMA models // -static std::string llama_model_ftype_name(enum llama_ftype ftype) { +static std::string llama_model_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { if (ftype & LLAMA_FTYPE_GUESSED) { return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; } @@ -1804,10 +1820,7 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) { static void llm_load_hparams( llama_model_loader & ml, - llama_model & model, - int n_ctx, - float rope_freq_base, - float rope_freq_scale) { + llama_model & model) { struct gguf_context * ctx = ml.ctx_gguf; const auto kv = LLM_KV(model.arch); @@ -1818,29 +1831,25 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); // rope_freq_base (optional) - if (rope_freq_base == 0.0f) { - rope_freq_base = 10000.0f; - GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - } + hparams.rope_freq_base_train = 10000.0f; + GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); // rope_freq_scale (inverse of the kv) is optional - if (rope_freq_scale == 0.0f) { - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - rope_freq_scale = 1.0f/ropescale; - } + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + hparams.rope_freq_scale_train = 1.0f/ropescale; // sanity check for n_rot (optional) { @@ -1907,10 +1916,6 @@ static void llm_load_hparams( }; model.ftype = ml.ftype; - - hparams.n_ctx = n_ctx; - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; } // TODO: This should probably be in llama.h @@ -2034,31 +2039,30 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { const auto & vocab = model.vocab; // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); if (ml.n_bytes < GB) { - LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } else { - LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } // general kv @@ -2076,13 +2080,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { static void llm_load_tensors( llama_model_loader & ml, llama_model & model, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - bool low_vram, - ggml_type memory_type, bool use_mlock, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -2121,11 +2121,9 @@ static void llm_load_tensors( } (void) main_gpu; - (void) mul_mat_q; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); - ggml_cuda_set_mul_mat_q(mul_mat_q); #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT #elif defined(GGML_USE_CLBLAST) @@ -2160,9 +2158,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2226,9 +2224,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2296,9 +2294,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2373,9 +2371,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2447,20 +2445,12 @@ static void llm_load_tensors( // print memory requirements { - const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; - // this is the total memory required to run the inference size_t mem_required = ctx_size + mmapped_size - vram_weights; // weights in VRAM not in memory - // this is the memory required by one llama_state - const size_t mem_required_state = scale*hparams.kv_size(); - - LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - - (void) n_batch; + LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); @@ -2469,36 +2459,17 @@ static void llm_load_tensors( if (n_gpu_layers > (int) hparams.n_layer) { LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); } - size_t vram_kv_cache = 0; #ifdef GGML_USE_CUBLAS const int max_backend_supported_layers = hparams.n_layer + 3; - const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; - if (n_gpu_layers > (int) hparams.n_layer + 1) { - if (low_vram) { - LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } - if (n_gpu_layers > (int) hparams.n_layer + 2) { - if (low_vram) { - LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } + const int max_offloadable_layers = hparams.n_layer + 3; #elif defined(GGML_USE_CLBLAST) const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; #endif // GGML_USE_CUBLAS - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", - __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n", - __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); #else (void) n_gpu_layers; #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) @@ -2511,7 +2482,7 @@ static void llm_load_tensors( } (void) tensor_split; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS { ggml_cuda_set_tensor_split(tensor_split); } @@ -2533,29 +2504,24 @@ static void llm_load_tensors( static bool llama_model_load( const std::string & fname, llama_model & model, - int n_ctx, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - float rope_freq_base, - float rope_freq_scale, - bool low_vram, - ggml_type memory_type, bool use_mmap, bool use_mlock, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname, use_mmap)); + llama_model_loader ml(fname, use_mmap); - llm_load_arch (*ml, model); - llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); - llm_load_vocab (*ml, model); + model.hparams.vocab_only = vocab_only; - llm_load_print_meta(*ml, model); + llm_load_arch (ml, model); + llm_load_hparams(ml, model); + llm_load_vocab (ml, model); + + llm_load_print_meta(ml, model); if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); @@ -2567,8 +2533,8 @@ static bool llama_model_load( } llm_load_tensors( - *ml, model, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + ml, model, n_gpu_layers, + main_gpu, tensor_split, use_mlock, progress_callback, progress_callback_user_data); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); @@ -2583,6 +2549,7 @@ static struct ggml_cgraph * llm_build_llama( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2590,7 +2557,7 @@ static struct ggml_cgraph * llm_build_llama( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2598,8 +2565,8 @@ static struct ggml_cgraph * llm_build_llama( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -2657,9 +2624,6 @@ static struct ggml_cgraph * llm_build_llama( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -2975,6 +2939,7 @@ static struct ggml_cgraph * llm_build_baichaun( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2982,7 +2947,7 @@ static struct ggml_cgraph * llm_build_baichaun( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2990,8 +2955,8 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3047,9 +3012,6 @@ static struct ggml_cgraph * llm_build_baichaun( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3382,6 +3344,7 @@ static struct ggml_cgraph * llm_build_falcon( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3389,7 +3352,7 @@ static struct ggml_cgraph * llm_build_falcon( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -3397,8 +3360,8 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_eps = hparams.f_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3457,9 +3420,6 @@ static struct ggml_cgraph * llm_build_falcon( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3753,6 +3713,7 @@ static struct ggml_cgraph * llm_build_starcoder( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3760,7 +3721,7 @@ static struct ggml_cgraph * llm_build_starcoder( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -4037,8 +3998,7 @@ static struct ggml_cgraph * llama_build_graph( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch, - int n_threads) { + llama_batch batch) { const uint32_t n_tokens = batch.n_tokens; if (n_tokens == 0) { @@ -4046,6 +4006,15 @@ static int llama_decode_internal( return -1; } + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto n_batch = cparams.n_batch; + + GGML_ASSERT(n_tokens <= n_batch); + + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT const int64_t t_start_us = ggml_time_us(); @@ -4058,9 +4027,6 @@ static int llama_decode_internal( GGML_ASSERT(n_threads > 0); - const auto & model = lctx.model; - const auto & hparams = model.hparams; - auto & kv_self = lctx.kv_self; GGML_ASSERT(!!kv_self.ctx); @@ -4103,7 +4069,7 @@ static int llama_decode_internal( // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? - kv_self.n = std::min((int32_t) hparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); //printf("kv_self.n = %d\n", kv_self.n); @@ -4128,6 +4094,8 @@ static int llama_decode_internal( ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); } } + + ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); #endif // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -5416,7 +5384,7 @@ void llama_sample_classifier_free_guidance( GGML_ASSERT(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); GGML_ASSERT(n_vocab == (int)candidates->size); GGML_ASSERT(!candidates->sorted); @@ -5445,7 +5413,7 @@ void llama_sample_classifier_free_guidance( llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { GGML_ASSERT(ctx); - auto N = float(llama_n_vocab(ctx)); + auto N = float(llama_n_vocab(llama_get_model(ctx))); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -5632,7 +5600,7 @@ struct llama_logit_info { }; llama_logit_info(llama_context * ctx) : logits(llama_get_logits(ctx)) - , n_vocab(llama_n_vocab(ctx)) + , n_vocab(llama_n_vocab(llama_get_model(ctx))) , max_l(*std::max_element(logits, logits + n_vocab)) , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) { } @@ -5670,7 +5638,6 @@ struct llama_beam_search_data { size_t n_beams; int n_past; int n_predict; - int n_threads; std::vector<llama_beam> beams; std::vector<llama_beam> next_beams; @@ -5680,12 +5647,11 @@ struct llama_beam_search_data { // Used to communicate to/from callback on beams state. std::vector<llama_beam_view> beam_views; - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) : ctx(ctx) , n_beams(n_beams) , n_past(n_past) , n_predict(n_predict) - , n_threads(n_threads) , beam_views(n_beams) { beams.reserve(n_beams); next_beams.reserve(n_beams); @@ -5722,7 +5688,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); } llama_logit_info logit_info(ctx); std::vector<llama_token_data> next_tokens = logit_info.top_k(n_beams); @@ -5796,7 +5762,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -5837,11 +5803,11 @@ struct llama_beam_search_data { void llama_beam_search(llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict, int n_threads) { + size_t n_beams, int n_past, int n_predict) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); beam_search_data.loop(callback, callback_data); @@ -6061,11 +6027,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + llama_model_loader ml(fname_inp, /*use_mmap*/ false); llama_model model; - llm_load_arch(*ml, model); - llm_load_hparams(*ml, model, 0, 0, 0); + llm_load_arch(ml, model); + llm_load_hparams(ml, model); if (params->only_copy) { ftype = model.ftype; @@ -6075,7 +6041,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s struct gguf_context * ctx_out = gguf_init_empty(); // copy the KV pairs from the input file - gguf_set_kv (ctx_out, ml->ctx_gguf); + gguf_set_kv (ctx_out, ml.ctx_gguf); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); gguf_set_val_u32(ctx_out, "general.file_type", ftype); @@ -6083,8 +6049,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); const std::string name = ggml_get_name(meta); @@ -6120,8 +6086,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector<no_init<float>> f32_conv_buf; // populate the original tensors so we get an initial meta data - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); gguf_add_tensor(ctx_out, meta); } @@ -6134,8 +6100,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // placeholder for the meta data ::zeros(fout, meta_size); - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * tensor = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * tensor = ml.get_tensor_meta(i); const std::string name = ggml_get_name(tensor); @@ -6143,10 +6109,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s read_data.resize(ggml_nbytes(tensor)); } tensor->data = read_data.data(); - ml->load_data_for(tensor); + ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, ml->n_tensors, + ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); @@ -6296,7 +6262,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -// TODO: after the GGUF PR, this likely won't work and needs to be updated static int llama_apply_lora_from_file_internal( const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads ) { @@ -6575,27 +6540,16 @@ static int llama_apply_lora_from_file_internal( // // interface implementation // - -struct llama_context_params llama_context_default_params() { - struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, +struct llama_model_params llama_model_default_params() { + struct llama_model_params result = { /*.n_gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, - /*.rope_freq_base =*/ 0.0f, - /*.rope_freq_scale =*/ 0.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, - /*.low_vram =*/ false, - /*.mul_mat_q =*/ true, - /*.f16_kv =*/ true, - /*.logits_all =*/ false, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, - /*.embedding =*/ false, }; #ifdef GGML_USE_METAL @@ -6605,6 +6559,24 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_context_params llama_context_default_params() { + struct llama_context_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + /*.n_batch =*/ 512, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, + /*.mul_mat_q =*/ true, + /*.f16_kv =*/ true, + /*.logits_all =*/ false, + /*.embedding =*/ false, + }; + + return result; +} + struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -6660,13 +6632,11 @@ int64_t llama_time_us(void) { struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_context_params params) { + struct llama_model_params params) { ggml_time_init(); llama_model * model = new llama_model; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; @@ -6683,9 +6653,9 @@ struct llama_model * llama_load_model_from_file( }; } - if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, - params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + if (!llama_model_load(path_model, *model, params.n_gpu_layers, + params.main_gpu, params.tensor_split, + params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); delete model; @@ -6709,18 +6679,33 @@ struct llama_context * llama_new_context_with_model( llama_context * ctx = new llama_context(*model); + const auto & hparams = model->hparams; + auto & cparams = ctx->cparams; + + cparams.n_batch = params.n_batch; + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers - if (!params.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) { + if (!hparams.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -6731,11 +6716,9 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - const auto & hparams = ctx->model.hparams; - // resized during inference if (params.logits_all) { - ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); } else { ctx->logits.reserve(hparams.n_vocab); } @@ -6753,12 +6736,13 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new_measure(tensor_alignment); // build worst-case graph - const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch); + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0)); + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { ctx->ctx_metal = ggml_metal_init(1); if (!ctx->ctx_metal) { LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); @@ -6773,7 +6757,7 @@ struct llama_context * llama_new_context_with_model( // measure memory requirements for the graph size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; - LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); // recreate allocator with exact memory requirements ggml_allocr_free(ctx->alloc); @@ -6786,24 +6770,42 @@ struct llama_context * llama_new_context_with_model( } #endif #ifdef GGML_USE_CUBLAS - if (params.low_vram) { - LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); - ggml_cuda_set_scratch_size(0); // disable scratch - } else { - ggml_cuda_set_scratch_size(alloc_size); - LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + + // calculate total VRAM usage + auto add_tensor = [](const ggml_tensor * t, size_t & size) { + if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { + size += ggml_nbytes(t); + } + }; + size_t model_vram_size = 0; + for (const auto & kv : model->tensors_by_name) { + add_tensor(kv.second, model_vram_size); } + + size_t kv_vram_size = 0; + add_tensor(ctx->kv_self.k, kv_vram_size); + add_tensor(ctx->kv_self.v, kv_vram_size); + + size_t ctx_vram_size = alloc_size + kv_vram_size; + size_t total_vram_size = model_vram_size + ctx_vram_size; + + LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, + total_vram_size / 1024.0 / 1024.0, + model_vram_size / 1024.0 / 1024.0, + ctx_vram_size / 1024.0 / 1024.0); #endif } #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { // this allocates all Metal resources and memory buffers void * data_ptr = NULL; size_t data_size = 0; - if (params.use_mmap) { + if (ctx->model.mapping) { data_ptr = ctx->model.mapping->addr; data_size = ctx->model.mapping->size; } else { @@ -6822,11 +6824,8 @@ struct llama_context * llama_new_context_with_model( return NULL; \ } - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); - + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); #undef LLAMA_METAL_CHECK_BUF } @@ -6850,63 +6849,37 @@ struct llama_context * llama_new_context_with_model( return ctx; } -static struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params) { - struct llama_model * model = llama_load_model_from_file(path_model, params); - if (!model) { - return nullptr; - } - - struct llama_context * ctx = llama_new_context_with_model(model, params); - ctx->model_owner = true; - - return ctx; -} - void llama_free(struct llama_context * ctx) { delete ctx; } -int llama_n_vocab(const struct llama_context * ctx) { - return llama_model_n_vocab(&ctx->model); +const llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; } int llama_n_ctx(const struct llama_context * ctx) { - return llama_model_n_ctx(&ctx->model); -} - -int llama_n_ctx_train(const struct llama_context * ctx) { - return llama_model_n_ctx_train(&ctx->model); + return ctx->cparams.n_ctx; } -int llama_n_embd(const struct llama_context * ctx) { - return llama_model_n_embd(&ctx->model); +enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { + return model->vocab.type; } -enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) { - return ctx->model.vocab.type; -} - -int llama_model_n_vocab(const struct llama_model * model) { +int llama_n_vocab(const struct llama_model * model) { return model->vocab.id_to_token.size(); } -int llama_model_n_ctx(const struct llama_model * model) { - return model->hparams.n_ctx; -} - -int llama_model_n_ctx_train(const struct llama_model * model) { +int llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } -int llama_model_n_embd(const struct llama_model * model) { +int llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { return snprintf(buf, buf_size, "%s %s %s", - model->name.c_str(), + llama_model_arch_name(model->arch).c_str(), llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str()); } @@ -7131,9 +7104,11 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; const size_t kv_size = kv_self.buf.size; const int kv_ntok = kv_self.head; @@ -7239,9 +7214,11 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; size_t kv_size; int kv_ntok; @@ -7378,11 +7355,10 @@ int llama_eval( struct llama_context * ctx, llama_token * tokens, int32_t n_tokens, - int n_past, - int n_threads) { + int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads); + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7394,13 +7370,12 @@ int llama_eval_embd( struct llama_context * ctx, float * embd, int32_t n_tokens, - int n_past, - int n_threads) { + int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; - const int ret = llama_decode_internal(*ctx, batch, n_threads); + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7408,6 +7383,11 @@ int llama_eval_embd( return ret; } +void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, @@ -7452,9 +7432,8 @@ void llama_batch_free(struct llama_batch batch) { int llama_decode( struct llama_context * ctx, - struct llama_batch batch, - int n_threads) { - const int ret = llama_decode_internal(*ctx, batch, n_threads); + struct llama_batch batch) { + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7499,16 +7478,6 @@ llama_token llama_token_nl(const struct llama_context * ctx) { } int llama_tokenize( - struct llama_context * ctx, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos); -} - -int llama_tokenize_with_model( const struct llama_model * model, const char * text, int text_len, @@ -7529,13 +7498,9 @@ int llama_tokenize_with_model( return res.size(); } -int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) { - return llama_token_to_piece_with_model(&ctx->model, token, buf, length); -} - // does not write null-terminator to buf -int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { - if (0 <= token && token < llama_model_n_vocab(model)) { +int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_n_vocab(model)) { if (llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].text; if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { |