summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--common/common.cpp20
-rw-r--r--common/common.h1
2 files changed, 13 insertions, 8 deletions
diff --git a/common/common.cpp b/common/common.cpp
index d2a8bb69..1591790e 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -200,19 +200,13 @@ void gpt_params_handle_model_default(gpt_params & params) {
}
params.hf_file = params.model;
} else if (params.model.empty()) {
- std::string cache_directory = fs_get_cache_directory();
- const bool success = fs_create_directory_with_parents(cache_directory);
- if (!success) {
- throw std::runtime_error("failed to create cache directory: " + cache_directory);
- }
- params.model = cache_directory + string_split(params.hf_file, '/').back();
+ params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
auto f = string_split(params.model_url, '#').front();
f = string_split(f, '?').front();
- f = string_split(f, '/').back();
- params.model = "models/" + f;
+ params.model = fs_get_cache_file(string_split(f, '/').back());
}
} else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH;
@@ -2279,6 +2273,16 @@ std::string fs_get_cache_directory() {
return ensure_trailing_slash(cache_directory);
}
+std::string fs_get_cache_file(const std::string & filename) {
+ GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
+ std::string cache_directory = fs_get_cache_directory();
+ const bool success = fs_create_directory_with_parents(cache_directory);
+ if (!success) {
+ throw std::runtime_error("failed to create cache directory: " + cache_directory);
+ }
+ return cache_directory + filename;
+}
+
//
// Model utils
diff --git a/common/common.h b/common/common.h
index 038f9084..2345d855 100644
--- a/common/common.h
+++ b/common/common.h
@@ -277,6 +277,7 @@ bool fs_validate_filename(const std::string & filename);
bool fs_create_directory_with_parents(const std::string & path);
std::string fs_get_cache_directory();
+std::string fs_get_cache_file(const std::string & filename);
//
// Model utils