summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/common.cpp8
-rw-r--r--common/common.h1
-rw-r--r--examples/llama-bench/llama-bench.cpp32
-rw-r--r--ggml/src/iqk/iqk_quantize.cpp79
-rw-r--r--ggml/src/iqk/iqk_quantize.h2
-rw-r--r--include/llama.h1
-rw-r--r--src/CMakeLists.txt1
-rw-r--r--src/llama.cpp28
8 files changed, 146 insertions, 6 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 75dd78e6..95e91bc1 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -906,6 +906,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.use_mmap = false;
return true;
}
+ if (arg == "-rtr" || arg == "--run-time-repack") {
+ params.repack_tensors = true;
+ params.use_mmap = false;
+ return true;
+ }
if (arg == "--numa") {
CHECK_ARG
std::string value(argv[i]);
@@ -1579,6 +1584,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
if (llama_supports_mmap()) {
options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" });
}
+ options.push_back({ "*", " --run-time-repack", "repack tensors if interleaved variant is available"});
options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n"
" - distribute: spread execution evenly over all nodes\n"
" - isolate: only spawn threads on CPUs on the node that execution started on\n"
@@ -2204,6 +2210,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
+ mparams.repack_tensors = params.repack_tensors;
if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL;
} else {
@@ -3244,6 +3251,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
+ fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false");
fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
diff --git a/common/common.h b/common/common.h
index 486017ef..73d7d650 100644
--- a/common/common.h
+++ b/common/common.h
@@ -187,6 +187,7 @@ struct gpt_params {
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data
+ bool repack_tensors = false; // repack tensors if interleaved variant is available
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 9e4fd266..7741a227 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -238,6 +238,7 @@ struct cmd_params {
int reps;
bool verbose;
bool warmup;
+ bool repack;
output_formats output_format;
output_formats output_format_stderr;
};
@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
/* reps */ 5,
/* verbose */ false,
/* warmup */ true,
+ /* repack */ false,
/* output_format */ MARKDOWN,
/* output_format_stderr */ NONE,
};
@@ -298,6 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -oe, --output-err <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr));
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
+ printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
}
@@ -571,6 +574,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.warmup = std::stoi(argv[i]);
+ } else if (arg == "-rtr" || arg == "--run-time-repack") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.repack = std::stoi(argv[i]);
} else {
invalid_param = true;
break;
@@ -623,6 +632,7 @@ struct cmd_params_instance {
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
+ bool repack;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
@@ -635,6 +645,7 @@ struct cmd_params_instance {
mparams.main_gpu = main_gpu;
mparams.tensor_split = tensor_split.data();
mparams.use_mmap = use_mmap;
+ mparams.repack_tensors = repack;
return mparams;
}
@@ -646,6 +657,7 @@ struct cmd_params_instance {
split_mode == other.split_mode &&
main_gpu == other.main_gpu &&
use_mmap == other.use_mmap &&
+ repack == other.repack &&
tensor_split == other.tensor_split;
}
@@ -706,6 +718,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -732,6 +745,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -758,6 +772,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -796,6 +811,7 @@ struct test {
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
+ bool repack;
int n_prompt;
int n_gen;
std::string test_time;
@@ -822,6 +838,7 @@ struct test {
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
+ repack = inst.repack;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
// RFC 3339 date-time format
@@ -891,7 +908,7 @@ struct test {
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn",
- "tensor_split", "use_mmap", "embeddings",
+ "tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts"
@@ -912,7 +929,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
+ field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -947,7 +964,7 @@ struct test {
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
- tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
+ tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts())
@@ -1112,6 +1129,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return 4;
}
+ if (field == "repack") {
+ return 3;
+ }
if (field == "test") {
return 13;
}
@@ -1143,6 +1163,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return "mmap";
}
+ if (field == "repack") {
+ return "rtr";
+ }
if (field == "embeddings") {
return "embd";
}
@@ -1198,6 +1221,9 @@ struct markdown_printer : public printer {
if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
fields.emplace_back("embeddings");
}
+ if (params.repack != cmd_params_defaults.repack) {
+ fields.emplace_back("repack");
+ }
fields.emplace_back("test");
fields.emplace_back("t/s");
diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp
index 3077fe21..3408d054 100644
--- a/ggml/src/iqk/iqk_quantize.cpp
+++ b/ggml/src/iqk/iqk_quantize.cpp
@@ -21,6 +21,9 @@
#include <algorithm>
#include <cstring>
#include <mutex>
+#include <thread>
+#include <atomic>
+#include <unordered_map>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -5054,3 +5057,79 @@ void vec_dot_iq2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t
GGML_UNUSED(by);
}
+namespace {
+struct Repack {
+ using repack_func = void (*) (int nrows, int n_per_row, const char * src, char * dst);
+ ggml_type new_type;
+ int num_rows;
+ repack_func repack;
+};
+}
+
+void iqk_repack_tensor(struct ggml_tensor * tensor) {
+ constexpr int kChunk = 8;
+ if (!tensor) return;
+ if (!ggml_is_contiguous(tensor)) return;
+ if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return;
+ if (tensor->ne[1] % 4 || tensor->ne[2]*tensor->ne[3] > 1) return;
+ static const std::unordered_map<ggml_type, Repack> k_map = {
+ { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} },
+ { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} },
+ { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} },
+ { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} },
+ { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} },
+ { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} },
+ { GGML_TYPE_Q2_K, { GGML_TYPE_Q2_K_R4, 4, (Repack::repack_func)repack_q2_k} },
+ { GGML_TYPE_Q3_K, { GGML_TYPE_Q3_K_R4, 4, (Repack::repack_func)repack_q3_k} },
+ { GGML_TYPE_Q4_K, { GGML_TYPE_Q4_K_R4, 4, (Repack::repack_func)repack_q4_k} },
+ { GGML_TYPE_Q5_K, { GGML_TYPE_Q5_K_R4, 4, (Repack::repack_func)repack_q5_k} },
+ { GGML_TYPE_Q6_K, { GGML_TYPE_Q6_K_R4, 4, (Repack::repack_func)repack_q6_k} },
+ { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} },
+ { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} },
+ { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} },
+ { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} },
+ { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} },
+#ifdef __AVX512BF16__
+ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16<ggml_bf16_t>} },
+#endif
+ };
+
+ auto it = k_map.find(tensor->type);
+ if (it == k_map.end()) return;
+ if (tensor->ne[1] % it->second.num_rows) return;
+
+ auto& r = it->second;
+
+ int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2));
+ int num_chunks = (tensor->ne[1] + kChunk*r.num_rows - 1)/(kChunk*r.num_rows);
+ int nthread = std::min(num_chunks, max_thread);
+
+ //printf("%s(%s): %s -> %s. %d rows, %d chunks, %d threads\n", __func__, tensor->name, ggml_type_name(tensor->type), ggml_type_name(r.new_type),
+ // int(tensor->ne[1]), num_chunks, nthread);
+
+ std::atomic<int> counter(0);;
+ auto compute = [&counter, &r, tensor, num_chunks] () {
+ int nrows = tensor->ne[1];
+ int n_per_row = tensor->ne[0];
+ auto row_size = ggml_row_size(tensor->type, n_per_row);
+ std::vector<char> qtmp(r.num_rows*row_size);
+ auto data = (char *)tensor->data;
+ while (true) {
+ int chunk = counter.fetch_add(1);
+ if (chunk >= num_chunks) break;
+ int first_row = chunk*kChunk*r.num_rows;
+ int last_row = std::min(first_row + kChunk*r.num_rows, nrows);
+ for (int row = first_row; row < last_row; row += r.num_rows) {
+ std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size);
+ r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size);
+ }
+ }
+ };
+ std::vector<std::thread> workers(nthread-1);
+ for (auto& w : workers) w = std::thread(compute);
+ compute();
+ for (auto& w : workers) w.join();
+
+ tensor->type = r.new_type;
+}
+
diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h
index 8640b59b..7c568ded 100644
--- a/ggml/src/iqk/iqk_quantize.h
+++ b/ggml/src/iqk/iqk_quantize.h
@@ -173,6 +173,8 @@ void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row);
+void iqk_repack_tensor(struct ggml_tensor * tensor);
+
#ifdef __cplusplus
}
#endif
diff --git a/include/llama.h b/include/llama.h
index 1627a752..e63d76fe 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -325,6 +325,7 @@ extern "C" {
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
+ bool repack_tensors;// repack if available
};
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 46a6ad56..3aea013a 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -23,6 +23,7 @@ add_library(llama
)
target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PRIVATE ../ggml/src)
target_compile_features (llama PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama PUBLIC ggml)
diff --git a/src/llama.cpp b/src/llama.cpp
index 62ab4d08..68e59758 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9,6 +9,9 @@
#include "ggml-alloc.h"
#include "ggml-backend.h"
+// TODO: fix this include
+#include "iqk/iqk_quantize.h"
+
#ifdef GGML_USE_RPC
# include "ggml-rpc.h"
#endif
@@ -3653,6 +3656,7 @@ struct llama_model_loader {
bool use_mmap = false;
bool check_tensors;
+ bool repack_tensors = false;
llama_files files;
llama_ftype ftype;
@@ -3686,7 +3690,7 @@ struct llama_model_loader {
std::string arch_name;
LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
- llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
+ llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) {
int trace = 0;
if (getenv("LLAMA_TRACE")) {
trace = atoi(getenv("LLAMA_TRACE"));
@@ -3928,9 +3932,13 @@ struct llama_model_loader {
LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
use_mmap = false;
}
+ if (repack_tensors) {
+ use_mmap = false;
+ }
this->use_mmap = use_mmap;
this->check_tensors = check_tensors;
+ this->repack_tensors = repack_tensors;
}
~llama_model_loader() {
@@ -7880,6 +7888,19 @@ static bool llm_load_tensors(
}
}
+ if (!ml.use_mmap && ml.repack_tensors) {
+ int n_repacked = 0;
+ for (auto& it : model.tensors_by_name) {
+ if (ggml_backend_buffer_is_host(it.second->buffer)) {
+ auto orig_type = it.second->type;
+ iqk_repack_tensor(it.second);
+ if (it.second->type != orig_type) ++n_repacked;
+ //printf("Repacking tensor %s\n", it.first.c_str());
+ }
+ }
+ if (n_repacked > 0) printf("============ Repacked %d tensors\n", n_repacked);
+ }
+
if (model.arch == LLM_ARCH_BITNET) {
auto set_scale = [] (ggml_tensor * w, ggml_tensor * s) {
if (!s) {
@@ -7915,7 +7936,7 @@ static bool llm_load_tensors(
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
try {
- llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+ llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides);
model.hparams.vocab_only = params.vocab_only;
@@ -16333,7 +16354,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
- llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+ llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides);
ml.init_mappings(false); // no prefetching
llama_model model;
@@ -17007,6 +17028,7 @@ struct llama_model_params llama_model_default_params() {
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
/*.check_tensors =*/ false,
+ /*.repack_tensors =*/ false,
};
#ifdef GGML_USE_METAL