Diffstat (limited to 'src/llama.cpp')
-rw-r--r--  src/llama.cpp  28
1 file changed, 25 insertions(+), 3 deletions(-)
diff --git a/src/llama.cpp b/src/llama.cpp
index 62ab4d08..68e59758 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9,6 +9,9 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+// TODO: fix this include
+#include "iqk/iqk_quantize.h"
+
 #ifdef GGML_USE_RPC
 #  include "ggml-rpc.h"
 #endif
@@ -3653,6 +3656,7 @@ struct llama_model_loader {
 
     bool use_mmap = false;
     bool check_tensors;
+    bool repack_tensors = false;
 
     llama_files files;
     llama_ftype ftype;
@@ -3686,7 +3690,7 @@ struct llama_model_loader {
     std::string arch_name;
     LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) {
         int trace = 0;
         if (getenv("LLAMA_TRACE")) {
             trace = atoi(getenv("LLAMA_TRACE"));
@@ -3928,9 +3932,13 @@ struct llama_model_loader {
             LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
             use_mmap = false;
         }
+        if (repack_tensors) {
+            use_mmap = false;
+        }
 
         this->use_mmap = use_mmap;
         this->check_tensors = check_tensors;
+        this->repack_tensors = repack_tensors;
     }
 
     ~llama_model_loader() {
@@ -7880,6 +7888,19 @@ static bool llm_load_tensors(
         }
     }
 
+    if (!ml.use_mmap && ml.repack_tensors) {
+        int n_repacked = 0;
+        for (auto& it : model.tensors_by_name) {
+            if (ggml_backend_buffer_is_host(it.second->buffer)) {
+                auto orig_type = it.second->type;
+                iqk_repack_tensor(it.second);
+                if (it.second->type != orig_type) ++n_repacked;
+                //printf("Repacking tensor %s\n", it.first.c_str());
+            }
+        }
+        if (n_repacked > 0) printf("============ Repacked %d tensors\n", n_repacked);
+    }
+
     if (model.arch == LLM_ARCH_BITNET) {
         auto set_scale = [] (ggml_tensor * w, ggml_tensor * s) {
             if (!s) {
@@ -7915,7 +7936,7 @@ static bool llm_load_tensors(
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides);
 
         model.hparams.vocab_only = params.vocab_only;
@@ -16333,7 +16354,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
     llama_model model;
@@ -17007,6 +17028,7 @@ struct llama_model_params llama_model_default_params() {
         /*.use_mmap       =*/ true,
        /*.use_mlock       =*/ false,
        /*.check_tensors   =*/ false,
+        /*.repack_tensors  =*/ false,
     };
 
 #ifdef GGML_USE_METAL
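
For illustration, a minimal sketch of how application code could opt into the new flag, assuming repack_tensors is exposed in struct llama_model_params through llama.h (as its addition to llama_model_default_params() above suggests); the other calls are the standard llama.cpp loading API, and this is not part of the commit itself:

// Sketch: enable tensor repacking when loading a model.
// Assumes llama_model_params::repack_tensors is visible via llama.h.

#include "llama.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.repack_tensors = true; // request repacking of host-buffer tensors
    // Note: the loader clears use_mmap when repack_tensors is set, since the
    // tensor data must be rewritten in memory after loading.

    llama_model * model = llama_load_model_from_file(argv[1], mparams);
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        llama_backend_free();
        return 1;
    }

    // ... create a context and run inference as usual ...

    llama_free_model(model);
    llama_backend_free();
    return 0;
}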