diff options
-rw-r--r-- | common/common.cpp | 8 | ||||
-rw-r--r-- | common/common.h | 1 | ||||
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 32 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.cpp | 79 | ||||
-rw-r--r-- | ggml/src/iqk/iqk_quantize.h | 2 | ||||
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/llama.cpp | 28 |
8 files changed, 146 insertions, 6 deletions
diff --git a/common/common.cpp b/common/common.cpp index 75dd78e6..95e91bc1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -906,6 +906,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mmap = false; return true; } + if (arg == "-rtr" || arg == "--run-time-repack") { + params.repack_tensors = true; + params.use_mmap = false; + return true; + } if (arg == "--numa") { CHECK_ARG std::string value(argv[i]); @@ -1579,6 +1584,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param if (llama_supports_mmap()) { options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); } + options.push_back({ "*", " --run-time-repack", "repack tensors if interleaved variant is available"}); options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n" " - distribute: spread execution evenly over all nodes\n" " - isolate: only spawn threads on CPUs on the node that execution started on\n" @@ -2204,6 +2210,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + mparams.repack_tensors = params.repack_tensors; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -3244,6 +3251,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false"); fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); diff --git a/common/common.h b/common/common.h index 486017ef..73d7d650 100644 --- a/common/common.h +++ b/common/common.h @@ -187,6 +187,7 @@ struct gpt_params { bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool repack_tensors = false; // repack tensors if interleaved variant is available std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 9e4fd266..7741a227 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -238,6 +238,7 @@ struct cmd_params { int reps; bool verbose; bool warmup; + bool repack; output_formats output_format; output_formats output_format_stderr; }; @@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = { /* reps */ 5, /* verbose */ false, /* warmup */ true, + /* repack */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -298,6 +300,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -oe, --output-err <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); + printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -571,6 +574,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.warmup = std::stoi(argv[i]); + } else if (arg == "-rtr" || arg == "--run-time-repack") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.repack = std::stoi(argv[i]); } else { invalid_param = true; break; @@ -623,6 +632,7 @@ struct cmd_params_instance { std::vector<float> tensor_split; bool use_mmap; bool embeddings; + bool repack; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -635,6 +645,7 @@ struct cmd_params_instance { mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; + mparams.repack_tensors = repack; return mparams; } @@ -646,6 +657,7 @@ struct cmd_params_instance { split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && + repack == other.repack && tensor_split == other.tensor_split; } @@ -706,6 +718,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -732,6 +745,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -758,6 +772,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -796,6 +811,7 @@ struct test { std::vector<float> tensor_split; bool use_mmap; bool embeddings; + bool repack; int n_prompt; int n_gen; std::string test_time; @@ -822,6 +838,7 @@ struct test { tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; + repack = inst.repack; n_prompt = inst.n_prompt; n_gen = inst.n_gen; // RFC 3339 date-time format @@ -891,7 +908,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", - "tensor_split", "use_mmap", "embeddings", + "tensor_split", "use_mmap", "embeddings", "repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts" @@ -912,7 +929,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -947,7 +964,7 @@ struct test { std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()) @@ -1112,6 +1129,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return 4; } + if (field == "repack") { + return 3; + } if (field == "test") { return 13; } @@ -1143,6 +1163,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return "mmap"; } + if (field == "repack") { + return "rtr"; + } if (field == "embeddings") { return "embd"; } @@ -1198,6 +1221,9 @@ struct markdown_printer : public printer { if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) { fields.emplace_back("embeddings"); } + if (params.repack != cmd_params_defaults.repack) { + fields.emplace_back("repack"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 3077fe21..3408d054 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -21,6 +21,9 @@ #include <algorithm> #include <cstring> #include <mutex> +#include <thread> +#include <atomic> +#include <unordered_map> #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -5054,3 +5057,79 @@ void vec_dot_iq2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +namespace { +struct Repack { + using repack_func = void (*) (int nrows, int n_per_row, const char * src, char * dst); + ggml_type new_type; + int num_rows; + repack_func repack; +}; +} + +void iqk_repack_tensor(struct ggml_tensor * tensor) { + constexpr int kChunk = 8; + if (!tensor) return; + if (!ggml_is_contiguous(tensor)) return; + if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return; + if (tensor->ne[1] % 4 || tensor->ne[2]*tensor->ne[3] > 1) return; + static const std::unordered_map<ggml_type, Repack> k_map = { + { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} }, + { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, + { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} }, + { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} }, + { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, + { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, + { GGML_TYPE_Q2_K, { GGML_TYPE_Q2_K_R4, 4, (Repack::repack_func)repack_q2_k} }, + { GGML_TYPE_Q3_K, { GGML_TYPE_Q3_K_R4, 4, (Repack::repack_func)repack_q3_k} }, + { GGML_TYPE_Q4_K, { GGML_TYPE_Q4_K_R4, 4, (Repack::repack_func)repack_q4_k} }, + { GGML_TYPE_Q5_K, { GGML_TYPE_Q5_K_R4, 4, (Repack::repack_func)repack_q5_k} }, + { GGML_TYPE_Q6_K, { GGML_TYPE_Q6_K_R4, 4, (Repack::repack_func)repack_q6_k} }, + { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} }, + { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, + { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, + { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} }, + { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, +#ifdef __AVX512BF16__ + { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>} }, +#endif + }; + + auto it = k_map.find(tensor->type); + if (it == k_map.end()) return; + if (tensor->ne[1] % it->second.num_rows) return; + + auto& r = it->second; + + int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int num_chunks = (tensor->ne[1] + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); + int nthread = std::min(num_chunks, max_thread); + + //printf("%s(%s): %s -> %s. %d rows, %d chunks, %d threads\n", __func__, tensor->name, ggml_type_name(tensor->type), ggml_type_name(r.new_type), + // int(tensor->ne[1]), num_chunks, nthread); + + std::atomic<int> counter(0);; + auto compute = [&counter, &r, tensor, num_chunks] () { + int nrows = tensor->ne[1]; + int n_per_row = tensor->ne[0]; + auto row_size = ggml_row_size(tensor->type, n_per_row); + std::vector<char> qtmp(r.num_rows*row_size); + auto data = (char *)tensor->data; + while (true) { + int chunk = counter.fetch_add(1); + if (chunk >= num_chunks) break; + int first_row = chunk*kChunk*r.num_rows; + int last_row = std::min(first_row + kChunk*r.num_rows, nrows); + for (int row = first_row; row < last_row; row += r.num_rows) { + std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size); + r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size); + } + } + }; + std::vector<std::thread> workers(nthread-1); + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + + tensor->type = r.new_type; +} + diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 8640b59b..7c568ded 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -173,6 +173,8 @@ void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); +void iqk_repack_tensor(struct ggml_tensor * tensor); + #ifdef __cplusplus } #endif diff --git a/include/llama.h b/include/llama.h index 1627a752..e63d76fe 100644 --- a/include/llama.h +++ b/include/llama.h @@ -325,6 +325,7 @@ extern "C" { bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data + bool repack_tensors;// repack if available }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad56..3aea013a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama ) target_include_directories(llama PUBLIC . ../include) +target_include_directories(llama PRIVATE ../ggml/src) target_compile_features (llama PUBLIC cxx_std_11) # don't bump target_link_libraries(llama PUBLIC ggml) diff --git a/src/llama.cpp b/src/llama.cpp index 62ab4d08..68e59758 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9,6 +9,9 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +// TODO: fix this include +#include "iqk/iqk_quantize.h" + #ifdef GGML_USE_RPC # include "ggml-rpc.h" #endif @@ -3653,6 +3656,7 @@ struct llama_model_loader { bool use_mmap = false; bool check_tensors; + bool repack_tensors = false; llama_files files; llama_ftype ftype; @@ -3686,7 +3690,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3928,9 +3932,13 @@ struct llama_model_loader { LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); use_mmap = false; } + if (repack_tensors) { + use_mmap = false; + } this->use_mmap = use_mmap; this->check_tensors = check_tensors; + this->repack_tensors = repack_tensors; } ~llama_model_loader() { @@ -7880,6 +7888,19 @@ static bool llm_load_tensors( } } + if (!ml.use_mmap && ml.repack_tensors) { + int n_repacked = 0; + for (auto& it : model.tensors_by_name) { + if (ggml_backend_buffer_is_host(it.second->buffer)) { + auto orig_type = it.second->type; + iqk_repack_tensor(it.second); + if (it.second->type != orig_type) ++n_repacked; + //printf("Repacking tensor %s\n", it.first.c_str()); + } + } + if (n_repacked > 0) printf("============ Repacked %d tensors\n", n_repacked); + } + if (model.arch == LLM_ARCH_BITNET) { auto set_scale = [] (ggml_tensor * w, ggml_tensor * s) { if (!s) { @@ -7915,7 +7936,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides); model.hparams.vocab_only = params.vocab_only; @@ -16333,7 +16354,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model; @@ -17007,6 +17028,7 @@ struct llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, + /*.repack_tensors =*/ false, }; #ifdef GGML_USE_METAL |