summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorAmir <amir_zia@outlook.com>2024-05-21 17:13:12 +0300
committerGitHub <noreply@github.com>2024-05-21 17:13:12 +0300
commit11474e756de3f56b760986e73086d40e787e52f8 (patch)
treeffb1c5369b3e7e8f128a114c7a7f1b5899376ac9 /common
parentd8ee90222791afff2ab666ded4cb6195fd94cced (diff)
examples: cache hf model when --model not provided (#7353)
* examples: cache hf model when --model not provided * examples: cache hf model when --model not provided * examples: cache hf model when --model not provided * examples: cache hf model when --model not provided * examples: cache hf model when --model not provided
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp32
-rw-r--r--common/common.h1
2 files changed, 32 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp
index e624fc7f..ae11650b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1354,7 +1354,12 @@ void gpt_params_handle_model_default(gpt_params & params) {
}
params.hf_file = params.model;
} else if (params.model.empty()) {
- params.model = "models/" + string_split(params.hf_file, '/').back();
+ std::string cache_directory = get_cache_directory();
+ const bool success = create_directory_with_parents(cache_directory);
+ if (!success) {
+ throw std::runtime_error("failed to create cache directory: " + cache_directory);
+ }
+ params.model = cache_directory + string_split(params.hf_file, '/').back();
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
@@ -2516,6 +2521,31 @@ bool create_directory_with_parents(const std::string & path) {
#endif // _WIN32
}
+std::string get_cache_directory() {
+ std::string cache_directory = "";
+ if (getenv("LLAMA_CACHE")) {
+ cache_directory = std::getenv("LLAMA_CACHE");
+ if (cache_directory.back() != DIRECTORY_SEPARATOR) {
+ cache_directory += DIRECTORY_SEPARATOR;
+ }
+ } else {
+#ifdef __linux__
+ if (std::getenv("XDG_CACHE_HOME")) {
+ cache_directory = std::getenv("XDG_CACHE_HOME");
+ } else {
+ cache_directory = std::getenv("HOME") + std::string("/.cache/");
+ }
+#elif defined(__APPLE__)
+ cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+ cache_directory = std::getenv("APPDATA");
+#endif // __linux__
+ cache_directory += "llama.cpp";
+ cache_directory += DIRECTORY_SEPARATOR;
+ }
+ return cache_directory;
+}
+
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data) {
if (data.empty()) {
fprintf(stream, "%s:\n", prop_name);
diff --git a/common/common.h b/common/common.h
index 566490e2..a8e5e50e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -281,6 +281,7 @@ bool llama_should_add_bos_token(const llama_model * model);
//
bool create_directory_with_parents(const std::string & path);
+std::string get_cache_directory();
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);