summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp51
1 files changed, 51 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index f975aee3..464b4710 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -265,6 +265,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.emplace_back();
params.kv_overrides.back().key[0] = 0;
}
+ if (!params.tensor_buft_overrides.empty()) {
+ params.tensor_buft_overrides.push_back({nullptr, nullptr});
+ }
return true;
}
@@ -287,6 +290,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return true;
}
+namespace {
+bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tensor_buft_override>& overrides) {
+ /* static */ std::map<std::string, ggml_backend_buffer_type_t> buft_list;
+ if (buft_list.empty()) {
+ // enumerate all the devices and add their buffer types to the list
+ for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
+ //auto * dev = ggml_backend_reg_get_name(i);
+ auto * buft = ggml_backend_reg_get_default_buffer_type(i);
+ if (buft) {
+ buft_list[ggml_backend_buft_name(buft)] = buft;
+ }
+ }
+ }
+ for (const auto & override : string_split<std::string>(value, ',')) {
+ std::string::size_type pos = override.find('=');
+ if (pos == std::string::npos) {
+ fprintf(stderr, "Invalid buft override argument %s\n", value.c_str());
+ return false;
+ }
+ std::string tensor_name = override.substr(0, pos);
+ std::string buffer_type = override.substr(pos + 1);
+ if (buft_list.find(buffer_type) == buft_list.end()) {
+ fprintf(stderr, "Available buffer types:\n");
+ for (const auto & it : buft_list) {
+ fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second));
+ }
+ return false;
+ }
+ overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+ }
+ return true;
+}
+}
+
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
@@ -1120,6 +1157,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
+ if (arg == "--override-tensor" || arg == "-ot") {
+ CHECK_ARG
+ if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) {
+ fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
+ invalid_param = true;
+ }
+ return true;
+ }
if (arg == "--host") {
CHECK_ARG
params.hostname = argv[i];
@@ -2238,6 +2283,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
mparams.kv_overrides = params.kv_overrides.data();
}
+ if (params.tensor_buft_overrides.empty()) {
+ mparams.tensor_buft_overrides = NULL;
+ } else {
+ GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+ mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
+ }
return mparams;
}