author     Georgi Gerganov <ggerganov@gmail.com>  2024-03-22 15:33:38 +0200
committer  GitHub <noreply@github.com>  2024-03-22 15:33:38 +0200
commit     80bd33bc2c4be352697dc8473339f25e1085d117 (patch)
tree       aada7156008e4ad7fb0be8c6182e5d97f175b201 /common/common.cpp
parent     e80f06d2a194be62ab5b1cd7ef7c7a5b241dd4fb (diff)
common : add HF arg helpers (#6234)
* common : add HF arg helpers
* common : remove defaults
Diffstat (limited to 'common/common.cpp')
-rw-r--r--  common/common.cpp  85
1 file changed, 72 insertions(+), 13 deletions(-)
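
This commit adds two Hugging Face download arguments (-hfr/--hf-repo and -hff/--hf-file) and a llama_load_model_from_hf() helper that builds a https://huggingface.co/<repo>/resolve/main/<file> URL and forwards it to llama_load_model_from_url(). Below is a minimal, illustrative C++ sketch (not part of the patch) of calling the new helper directly; it assumes a build with LLAMA_USE_CURL, takes the repo/file values from the examples in the patch's own comment, and uses a hypothetical local download path:

    // Illustrative sketch only: exercises the helper added by this commit.
    // Assumes llama.cpp is built with LLAMA_USE_CURL; the local path below is
    // a placeholder, not a value defined by the patch.
    #include "common.h"
    #include "llama.h"
    #include <cstdio>

    int main() {
        struct llama_model_params mparams = llama_model_default_params();

        // Downloads https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
        // to the given local path, then loads it via llama_load_model_from_file().
        struct llama_model * model = llama_load_model_from_hf(
            "ggml-org/models",                    // -hfr / --hf-repo
            "tinyllama-1.1b/ggml-model-f16.gguf", // -hff / --hf-file
            "ggml-model-f16.gguf",                // -m / --model (local path, placeholder)
            mparams);

        if (model == NULL) {
            fprintf(stderr, "failed to download/load model from Hugging Face\n");
            return 1;
        }

        llama_free_model(model);
        return 0;
    }

On the command line the same download is requested with "-hfr ggml-org/models -hff tinyllama-1.1b/ggml-model-f16.gguf", optionally with -m to choose where the file is stored; since the patch removes the previous defaults, both --hf-repo and --hf-file must be supplied together for the Hugging Face path to be taken.
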
diff --git a/common/common.cpp b/common/common.cpp
index cc230c9f..0cc4859f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -647,6 +647,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model = argv[i];
return true;
}
+ if (arg == "-md" || arg == "--model-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ return true;
+ }
+ params.model_draft = argv[i];
+ return true;
+ }
+ if (arg == "-a" || arg == "--alias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ return true;
+ }
+ params.model_alias = argv[i];
+ return true;
+ }
if (arg == "-mu" || arg == "--model-url") {
if (++i >= argc) {
invalid_param = true;
@@ -655,20 +671,20 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
params.model_url = argv[i];
return true;
}
- if (arg == "-md" || arg == "--model-draft") {
+ if (arg == "-hfr" || arg == "--hf-repo") {
if (++i >= argc) {
invalid_param = true;
return true;
}
- params.model_draft = argv[i];
+ params.hf_repo = argv[i];
return true;
}
- if (arg == "-a" || arg == "--alias") {
+ if (arg == "-hff" || arg == "--hf-file") {
if (++i >= argc) {
invalid_param = true;
return true;
}
- params.model_alias = argv[i];
+ params.hf_file = argv[i];
return true;
}
if (arg == "--lora") {
@@ -1403,10 +1419,14 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
- printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
- printf(" model download url (default: %s)\n", params.model_url.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
- printf(" draft model for speculative decoding\n");
+ printf(" draft model for speculative decoding (default: unused)\n");
+ printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
+ printf(" model download url (default: unused)\n");
+ printf(" -hfr REPO, --hf-repo REPO\n");
+ printf(" Hugging Face model repository (default: unused)\n");
+ printf(" -hff FILE, --hf-file FILE\n");
+ printf(" Hugging Face model file (default: unused)\n");
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
@@ -1655,8 +1675,10 @@ void llama_batch_add(
#ifdef LLAMA_USE_CURL
-struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
- struct llama_model_params params) {
+struct llama_model * llama_load_model_from_url(
+ const char * model_url,
+ const char * path_model,
+ const struct llama_model_params & params) {
// Basic validation of the model_url
if (!model_url || strlen(model_url) == 0) {
fprintf(stderr, "%s: invalid model_url\n", __func__);
@@ -1850,25 +1872,62 @@ struct llama_model * llama_load_model_from_url(const char * model_url, const cha
return llama_load_model_from_file(path_model, params);
}
+struct llama_model * llama_load_model_from_hf(
+ const char * repo,
+ const char * model,
+ const char * path_model,
+ const struct llama_model_params & params) {
+ // construct hugging face model url:
+ //
+ // --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+ // https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+ //
+ // --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+ // https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+ //
+
+ std::string model_url = "https://huggingface.co/";
+ model_url += repo;
+ model_url += "/resolve/main/";
+ model_url += model;
+
+ return llama_load_model_from_url(model_url.c_str(), path_model, params);
+}
+
#else
-struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/,
- struct llama_model_params /*params*/) {
+struct llama_model * llama_load_model_from_url(
+ const char * /*model_url*/,
+ const char * /*path_model*/,
+ const struct llama_model_params & /*params*/) {
fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}
+struct llama_model * llama_load_model_from_hf(
+ const char * /*repo*/,
+ const char * /*model*/,
+ const char * /*path_model*/,
+ const struct llama_model_params & /*params*/) {
+ fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+ return nullptr;
+}
+
#endif // LLAMA_USE_CURL
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);
llama_model * model = nullptr;
- if (!params.model_url.empty()) {
+
+ if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+ model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+ } else if (!params.model_url.empty()) {
model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
}
+
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
@@ -1908,7 +1967,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
- const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+ const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
int err = llama_model_apply_lora_from_file(model,
lora_adapter.c_str(),