summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2024-12-17 14:16:34 +0100
committerGitHub <noreply@github.com>2024-12-17 14:16:34 +0100
commit514ae086200a8cfd78af6a71b6c6ee14931ddc0e (patch)
tree0fa47186d7c82afbf078d530f5436c7eb1ae4d79 /examples
parent4ade4c568c331acad22537f7b9519c740c7a06d0 (diff)
Be able to repack tensors at run time (#147)
* Be able to repack tensors at run time

* Repack: also add bf16 as repackable type

* Repack: make sure number of rows is a multiple of the packing

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r--examples/llama-bench/llama-bench.cpp32
1 file changed, 29 insertions, 3 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 9e4fd266..7741a227 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -238,6 +238,7 @@ struct cmd_params {
int reps;
bool verbose;
bool warmup;
+ bool repack;
output_formats output_format;
output_formats output_format_stderr;
};
@@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = {
/* reps */ 5,
/* verbose */ false,
/* warmup */ true,
+ /* repack */ false,
/* output_format */ MARKDOWN,
/* output_format_stderr */ NONE,
};
@@ -298,6 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -oe, --output-err <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr));
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
+ printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
}
@@ -571,6 +574,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.warmup = std::stoi(argv[i]);
+ } else if (arg == "-rtr" || arg == "--run-time-repack") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.repack = std::stoi(argv[i]);
} else {
invalid_param = true;
break;
@@ -623,6 +632,7 @@ struct cmd_params_instance {
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
+ bool repack;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
@@ -635,6 +645,7 @@ struct cmd_params_instance {
mparams.main_gpu = main_gpu;
mparams.tensor_split = tensor_split.data();
mparams.use_mmap = use_mmap;
+ mparams.repack_tensors = repack;
return mparams;
}
@@ -646,6 +657,7 @@ struct cmd_params_instance {
split_mode == other.split_mode &&
main_gpu == other.main_gpu &&
use_mmap == other.use_mmap &&
+ repack == other.repack &&
tensor_split == other.tensor_split;
}
@@ -706,6 +718,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -732,6 +745,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -758,6 +772,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
+ /* .repack = */ params.repack,
};
instances.push_back(instance);
}
@@ -796,6 +811,7 @@ struct test {
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
+ bool repack;
int n_prompt;
int n_gen;
std::string test_time;
@@ -822,6 +838,7 @@ struct test {
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
+ repack = inst.repack;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
// RFC 3339 date-time format
@@ -891,7 +908,7 @@ struct test {
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn",
- "tensor_split", "use_mmap", "embeddings",
+ "tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts"
@@ -912,7 +929,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
+ field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -947,7 +964,7 @@ struct test {
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
- tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
+ tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts())
@@ -1112,6 +1129,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return 4;
}
+ if (field == "repack") {
+ return 3;
+ }
if (field == "test") {
return 13;
}
@@ -1143,6 +1163,9 @@ struct markdown_printer : public printer {
if (field == "use_mmap") {
return "mmap";
}
+ if (field == "repack") {
+ return "rtr";
+ }
if (field == "embeddings") {
return "embd";
}
@@ -1198,6 +1221,9 @@ struct markdown_printer : public printer {
if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
fields.emplace_back("embeddings");
}
+ if (params.repack != cmd_params_defaults.repack) {
+ fields.emplace_back("repack");
+ }
fields.emplace_back("test");
fields.emplace_back("t/s");