diff options
author | Pierrick Hymbert <pierrick.hymbert@gmail.com> | 2024-03-17 19:12:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 19:12:37 +0100 |
commit | d01b3c4c32357567f3531d4e6ceffc5d23e87583 (patch) | |
tree | 80e0a075a8b120d6b5b095a73cc36cb2a4535aed /common/common.cpp | |
parent | cd776c37c945bf58efc8fe44b370456680cb1b59 (diff) |
common: llama_load_model_from_url using --model-url (#6098)
* common: llama_load_model_from_url with libcurl dependency
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 238 |
1 files changed, 237 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp index 1b0ba849..2f5d965d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -37,6 +37,9 @@ #include <sys/stat.h> #include <unistd.h> #endif +#if defined(LLAMA_USE_CURL) +#include <curl/curl.h> +#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -50,6 +53,18 @@ #define GGML_USE_CUBLAS_SYCL_VULKAN #endif +#if defined(LLAMA_USE_CURL) +#ifdef __linux__ +#include <linux/limits.h> +#elif defined(_WIN32) +#define PATH_MAX MAX_PATH +#else +#include <sys/syslimits.h> +#endif +#define LLAMA_CURL_MAX_PATH_LENGTH PATH_MAX +#define LLAMA_CURL_MAX_HEADER_LENGTH 256 +#endif // LLAMA_USE_CURL + int32_t get_num_physical_cores() { #ifdef __linux__ // enumerate the set of thread siblings, num entries is num cores @@ -644,6 +659,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } params.model = argv[i]; } + if (arg == "-mu" || arg == "--model-url") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_url = argv[i]; + } if (arg == "-md" || arg == "--model-draft") { arg_found = true; if (++i >= argc) { @@ -1368,6 +1390,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); + printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); + printf(" model download url (default: %s)\n", params.model_url.c_str()); printf(" -md FNAME, --model-draft FNAME\n"); printf(" draft model for speculative decoding\n"); printf(" -ld LOGDIR, --logdir LOGDIR\n"); @@ -1613,10 +1637,222 @@ void llama_batch_add( batch.n_tokens++; } +#ifdef LLAMA_USE_CURL + +struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, + struct llama_model_params params) { + // Basic validation of the model_url + if (!model_url || strlen(model_url) == 0) { + fprintf(stderr, "%s: invalid model_url\n", __func__); + return NULL; + } + + // Initialize libcurl globally + auto curl = curl_easy_init(); + + if (!curl) { + fprintf(stderr, "%s: error initializing libcurl\n", __func__); + return NULL; + } + + // Set the URL, allow to follow http redirection + curl_easy_setopt(curl, CURLOPT_URL, model_url); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); +#if defined(_WIN32) + // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of + // operating system. Currently implemented under MS-Windows. + curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + + // Check if the file already exists locally + struct stat model_file_info; + auto file_exists = (stat(path_model, &model_file_info) == 0); + + // If the file exists, check for ${path_model}.etag or ${path_model}.lastModified files + char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + char etag_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; + snprintf(etag_path, sizeof(etag_path), "%s.etag", path_model); + + char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + char last_modified_path[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; + snprintf(last_modified_path, sizeof(last_modified_path), "%s.lastModified", path_model); + + if (file_exists) { + auto * f_etag = fopen(etag_path, "r"); + if (f_etag) { + if (!fgets(etag, sizeof(etag), f_etag)) { + fprintf(stderr, "%s: unable to read file %s\n", __func__, etag_path); + } else { + fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, etag_path, etag); + } + fclose(f_etag); + } + + auto * f_last_modified = fopen(last_modified_path, "r"); + if (f_last_modified) { + if (!fgets(last_modified, sizeof(last_modified), f_last_modified)) { + fprintf(stderr, "%s: unable to read file %s\n", __func__, last_modified_path); + } else { + fprintf(stderr, "%s: previous model file found %s: %s\n", __func__, last_modified_path, + last_modified); + } + fclose(f_last_modified); + } + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct llama_load_model_from_url_headers { + char etag[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + char last_modified[LLAMA_CURL_MAX_HEADER_LENGTH] = {0}; + }; + llama_load_model_from_url_headers headers; + { + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; + + const char * etag_prefix = "etag: "; + if (strncmp(buffer, etag_prefix, strlen(etag_prefix)) == 0) { + strncpy(headers->etag, buffer + strlen(etag_prefix), n_items - strlen(etag_prefix) - 2); // Remove CRLF + } + + const char * last_modified_prefix = "last-modified: "; + if (strncmp(buffer, last_modified_prefix, strlen(last_modified_prefix)) == 0) { + strncpy(headers->last_modified, buffer + strlen(last_modified_prefix), + n_items - strlen(last_modified_prefix) - 2); // Remove CRLF + } + return n_items; + }; + + curl_easy_setopt(curl, CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback)); + curl_easy_setopt(curl, CURLOPT_HEADERDATA, &headers); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + curl_easy_cleanup(curl); + fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + return NULL; + } + + long http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code != 200) { + // HEAD not supported, we don't know if the file has changed + // force trigger downloading + file_exists = false; + fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + } + } + + // If the ETag or the Last-Modified headers are different: trigger a new download + if (!file_exists || strcmp(etag, headers.etag) != 0 || strcmp(last_modified, headers.last_modified) != 0) { + char path_model_temporary[LLAMA_CURL_MAX_PATH_LENGTH] = {0}; + snprintf(path_model_temporary, sizeof(path_model_temporary), "%s.downloadInProgress", path_model); + if (file_exists) { + fprintf(stderr, "%s: deleting previous downloaded model file: %s\n", __func__, path_model); + if (remove(path_model) != 0) { + curl_easy_cleanup(curl); + fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path_model); + return NULL; + } + } + + // Set the output file + auto * outfile = fopen(path_model_temporary, "wb"); + if (!outfile) { + curl_easy_cleanup(curl); + fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model); + return NULL; + } + + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); + auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { + return fwrite(data, size, nmemb, (FILE *)fd); + }; + curl_easy_setopt(curl, CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback)); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile); + + // display download progress + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + + // start the download + fprintf(stderr, "%s: downloading model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + model_url, path_model, headers.etag, headers.last_modified); + auto res = curl_easy_perform(curl); + if (res != CURLE_OK) { + fclose(outfile); + curl_easy_cleanup(curl); + fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + return NULL; + } + + long http_code = 0; + curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + fclose(outfile); + curl_easy_cleanup(curl); + fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); + return NULL; + } + + // Clean up + fclose(outfile); + + // Write the new ETag to the .etag file + if (strlen(headers.etag) > 0) { + auto * etag_file = fopen(etag_path, "w"); + if (etag_file) { + fputs(headers.etag, etag_file); + fclose(etag_file); + fprintf(stderr, "%s: model etag saved %s: %s\n", __func__, etag_path, headers.etag); + } + } + + // Write the new lastModified to the .etag file + if (strlen(headers.last_modified) > 0) { + auto * last_modified_file = fopen(last_modified_path, "w"); + if (last_modified_file) { + fputs(headers.last_modified, last_modified_file); + fclose(last_modified_file); + fprintf(stderr, "%s: model last modified saved %s: %s\n", __func__, last_modified_path, + headers.last_modified); + } + } + + if (rename(path_model_temporary, path_model) != 0) { + curl_easy_cleanup(curl); + fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_model_temporary, path_model); + return NULL; + } + } + + curl_easy_cleanup(curl); + + return llama_load_model_from_file(path_model, params); +} + +#else + +struct llama_model * llama_load_model_from_url(const char * /*model_url*/, const char * /*path_model*/, + struct llama_model_params /*params*/) { + fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + return nullptr; +} + +#endif // LLAMA_USE_CURL + std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = nullptr; + if (!params.model_url.empty()) { + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + } else { + model = llama_load_model_from_file(params.model.c_str(), mparams); + } if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return std::make_tuple(nullptr, nullptr); |