author    | Amir <amir_zia@outlook.com> | 2024-05-21 17:13:12 +0300
committer | GitHub <noreply@github.com> | 2024-05-21 17:13:12 +0300
commit    | 11474e756de3f56b760986e73086d40e787e52f8
tree      | ffb1c5369b3e7e8f128a114c7a7f1b5899376ac9 /common
parent    | d8ee90222791afff2ab666ded4cb6195fd94cced
examples: cache hf model when --model not provided (#7353)
Diffstat (limited to 'common')
common/common.cpp | 32
common/common.h   |  1
2 files changed, 32 insertions(+), 1 deletion(-)
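
Previously, when --hf-file was given without --model, the downloaded model was placed under a hard-coded models/ directory relative to the current working directory. With this change, the default destination becomes a per-user cache directory: LLAMA_CACHE if set, otherwise the platform convention (XDG_CACHE_HOME or ~/.cache on Linux, ~/Library/Caches on macOS, %APPDATA% on Windows) with a llama.cpp subdirectory appended. The directory is created on demand via create_directory_with_parents(); a standalone sketch of the resolution logic follows the diff.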
diff --git a/common/common.cpp b/common/common.cpp
index e624fc7f..ae11650b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1354,7 +1354,12 @@ void gpt_params_handle_model_default(gpt_params & params) {
             }
             params.hf_file = params.model;
         } else if (params.model.empty()) {
-            params.model = "models/" + string_split(params.hf_file, '/').back();
+            std::string cache_directory = get_cache_directory();
+            const bool success = create_directory_with_parents(cache_directory);
+            if (!success) {
+                throw std::runtime_error("failed to create cache directory: " + cache_directory);
+            }
+            params.model = cache_directory + string_split(params.hf_file, '/').back();
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
@@ -2516,6 +2521,31 @@ bool create_directory_with_parents(const std::string & path) {
 #endif // _WIN32
 }
 
+std::string get_cache_directory() {
+    std::string cache_directory = "";
+    if (getenv("LLAMA_CACHE")) {
+        cache_directory = std::getenv("LLAMA_CACHE");
+        if (cache_directory.back() != DIRECTORY_SEPARATOR) {
+            cache_directory += DIRECTORY_SEPARATOR;
+        }
+    } else {
+#ifdef __linux__
+        if (std::getenv("XDG_CACHE_HOME")) {
+            cache_directory = std::getenv("XDG_CACHE_HOME");
+        } else {
+            cache_directory = std::getenv("HOME") + std::string("/.cache/");
+        }
+#elif defined(__APPLE__)
+        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+        cache_directory = std::getenv("APPDATA");
+#endif // __linux__
+        cache_directory += "llama.cpp";
+        cache_directory += DIRECTORY_SEPARATOR;
+    }
+    return cache_directory;
+}
+
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data) {
     if (data.empty()) {
         fprintf(stream, "%s:\n", prop_name);
diff --git a/common/common.h b/common/common.h
index 566490e2..a8e5e50e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -281,6 +281,7 @@ bool llama_should_add_bos_token(const llama_model * model);
 //
 
 bool create_directory_with_parents(const std::string & path);
+std::string get_cache_directory();
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
 void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
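
For readers who want to experiment with the lookup order outside of common.cpp, here is a minimal standalone sketch of the same resolution logic. The env_or_empty helper, the null checks, and the trailing-separator guard before appending llama.cpp are defensive additions for the demo and are not part of the patch, which assumes HOME and APPDATA are set; the hf_file value is hypothetical.

```cpp
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>

#ifdef _WIN32
static const char DIR_SEP = '\\';
#else
static const char DIR_SEP = '/';
#endif

// Return the value of an environment variable, or "" if it is unset.
static std::string env_or_empty(const char * name) {
    const char * value = std::getenv(name);
    return value ? std::string(value) : std::string();
}

// Mirrors the lookup order of get_cache_directory() in the patch:
// LLAMA_CACHE wins outright; otherwise a platform default is used and
// a "llama.cpp" subdirectory is appended.
static std::string resolve_cache_directory() {
    std::string dir = env_or_empty("LLAMA_CACHE");
    if (!dir.empty()) {
        // LLAMA_CACHE is used as-is; only a trailing separator is ensured.
        if (dir.back() != DIR_SEP) {
            dir += DIR_SEP;
        }
        return dir;
    }
#if defined(__linux__)
    dir = env_or_empty("XDG_CACHE_HOME");
    if (dir.empty()) {
        dir = env_or_empty("HOME") + std::string("/.cache/");
    }
#elif defined(__APPLE__)
    dir = env_or_empty("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
    dir = env_or_empty("APPDATA");
#endif
    if (dir.empty()) {
        throw std::runtime_error("could not determine a cache directory");
    }
    if (dir.back() != DIR_SEP) {
        dir += DIR_SEP; // defensive addition, not present in the patch
    }
    dir += "llama.cpp";
    dir += DIR_SEP;
    return dir;
}

int main() {
    // With --hf-file set and --model omitted, the patch resolves the model
    // path as <cache dir> + <last '/'-separated component of hf_file>.
    const std::string hf_file = "ggml-model-q4_0.gguf"; // hypothetical file name
    std::printf("model path: %s%s\n", resolve_cache_directory().c_str(), hf_file.c_str());
    return 0;
}
```

For example, running the sketch with LLAMA_CACHE=/tmp/llama-cache would print model path: /tmp/llama-cache/ggml-model-q4_0.gguf, matching where the patched gpt_params_handle_model_default() would place the download once create_directory_with_parents() has created the directory.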