summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/llama-bench/llama-bench.cpp53
1 files changed, 53 insertions, 0 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index b0790e20..438d2a7c 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -236,6 +236,7 @@ struct cmd_params {
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
+ std::vector<llama_model_tensor_buft_override> buft_overrides;
ggml_numa_strategy numa;
int reps;
bool verbose;
@@ -267,6 +268,7 @@ static const cmd_params cmd_params_defaults = {
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
+ /* buft_overrides */ {},
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
/* reps */ 5,
/* verbose */ false,
@@ -309,6 +311,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
+ printf(" -ot, --override-tensor pattern (default: none)\n");
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
@@ -349,6 +352,39 @@ static ggml_type ggml_type_from_name(const std::string & s) {
return GGML_TYPE_COUNT;
}
+namespace {
+bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tensor_buft_override>& overrides) {
+ /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+ if (buft_list.empty()) {
+ // enumerate all the devices and add their buffer types to the list
+ for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
+ //auto * dev = ggml_backend_reg_get_name(i);
+ auto * buft = ggml_backend_reg_get_default_buffer_type(i);
+ if (buft) {
+ buft_list[ggml_backend_buft_name(buft)] = buft;
+ }
+ }
+ }
+ for (const auto & override : string_split<std::string>(value, ',')) {
+ std::string::size_type pos = override.find('=');
+ if (pos == std::string::npos) {
+ fprintf(stderr, "Invalid buft override argument %s\n", value.c_str());
+ return false;
+ }
+ std::string tensor_name = override.substr(0, pos);
+ std::string buffer_type = override.substr(pos + 1);
+ if (buft_list.find(buffer_type) == buft_list.end()) {
+ fprintf(stderr, "Available buffer types:\n");
+ for (const auto & it : buft_list) {
+ fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second));
+ }
+ return false;
+ }
+ overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+ }
+ return true;
+}
+}
static cmd_params parse_cmd_params(int argc, char ** argv) {
cmd_params params;
@@ -616,6 +652,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.fmoe = std::stoi(argv[i]);
+ } else if (arg == "-ot" || arg == "--override-tensor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ if (!parse_buft_overrides(std::string{argv[i]}, params.buft_overrides)) {
+ fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
} else {
invalid_param = true;
break;
@@ -648,6 +694,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; }
+ if (!params.buft_overrides.empty()) params.buft_overrides.emplace_back(llama_model_tensor_buft_override{nullptr, nullptr});
return params;
}
@@ -685,6 +732,7 @@ struct cmd_params_instance {
bool embeddings;
bool repack = false;
bool fmoe = false;
+ const llama_model_tensor_buft_override* buft_overrides;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
@@ -698,6 +746,7 @@ struct cmd_params_instance {
mparams.tensor_split = tensor_split.data();
mparams.use_mmap = use_mmap;
mparams.repack_tensors = repack;
+ mparams.tensor_buft_overrides = buft_overrides;
return mparams;
}
@@ -777,6 +826,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
+ /* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -807,6 +857,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
+ /* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -837,6 +888,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
+ /* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}
@@ -867,6 +919,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
+ /* .buft_overrides=*/ params.buft_overrides.data(),
};
instances.push_back(instance);
}