summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp73
1 files changed, 72 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 3e2df6e3..7d983a45 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -16,6 +16,7 @@
#include <unordered_set>
#include <vector>
#include <cinttypes>
+#include <codecvt>
#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
@@ -27,7 +28,6 @@
#ifndef NOMINMAX
# define NOMINMAX
#endif
-#include <codecvt>
#include <locale>
#include <windows.h>
#include <fcntl.h>
@@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
GGML_UNREACHABLE();
}
// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
//
// Returns true only if `filename`:
//   - is non-empty and at most 255 bytes long,
//   - is well-formed UTF-8 (no stray/truncated sequences, no overlong encodings, nothing above U+10FFFF),
//   - contains no forbidden codepoint (see the list inside the decode loop),
//   - does not start with ' ' and does not end with ' ' or '.',
//   - does not contain ".." anywhere.
//
// The UTF-8 decoding is done inline instead of through std::wstring_convert / std::codecvt_utf8:
// those facilities are deprecated since C++17, and a strict hand-written decoder rejects malformed
// input without relying on exceptions or on a decode/re-encode round trip.
bool validate_file_name(const std::string & filename) {
    if (filename.empty()) {
        // Empty filename invalid
        return false;
    }
    if (filename.length() > 255) {
        // Limit at common largest possible filename on Linux filesystems
        // to avoid unnecessary further validation
        // (On systems with smaller limits it will be caught by the OS)
        return false;
    }

    const size_t n_bytes = filename.size();
    size_t pos = 0;
    while (pos < n_bytes) {
        const unsigned char lead = static_cast<unsigned char>(filename[pos]);

        char32_t cp     = 0; // codepoint being assembled
        char32_t cp_min = 0; // smallest codepoint that legitimately needs this sequence length
        size_t   n_cont = 0; // number of continuation bytes expected after the lead byte

        if (lead < 0x80) {
            cp = lead;
        } else if ((lead & 0xE0) == 0xC0) {
            cp = lead & 0x1F; n_cont = 1; cp_min = 0x80;
        } else if ((lead & 0xF0) == 0xE0) {
            cp = lead & 0x0F; n_cont = 2; cp_min = 0x800;
        } else if ((lead & 0xF8) == 0xF0) {
            cp = lead & 0x07; n_cont = 3; cp_min = 0x10000;
        } else {
            // Stray continuation byte or a lead byte that cannot start a UTF-8 sequence
            return false;
        }

        if (n_cont > n_bytes - pos - 1) {
            // Sequence is truncated at the end of the string
            return false;
        }
        for (size_t k = 1; k <= n_cont; ++k) {
            const unsigned char cont = static_cast<unsigned char>(filename[pos + k]);
            if ((cont & 0xC0) != 0x80) {
                return false;
            }
            cp = (cp << 6) | (cont & 0x3F);
        }
        if (cp < cp_min || cp > 0x10FFFF) {
            // Overlong encoding (e.g. "/" smuggled in as C0 AF) or beyond the Unicode range
            return false;
        }
        pos += n_cont + 1;

        // Check for forbidden codepoints:
        // - Control characters
        // - Unicode equivalents of illegal characters
        // - UTF-16 surrogate pairs
        // - Unicode replacement character
        // - Byte order mark (BOM)
        // - Illegal characters: / \ : * ? " < > |
        if (cp <= 0x1F                          // Control characters (C0)
            || cp == 0x7F                       // Control characters (DEL)
            || (cp >= 0x80 && cp <= 0x9F)       // Control characters (C1)
            || cp == 0xFF0E                     // Fullwidth Full Stop (period equivalent)
            || cp == 0x2215                     // Division Slash (forward slash equivalent)
            || cp == 0x2216                     // Set Minus (backslash equivalent)
            || (cp >= 0xD800 && cp <= 0xDFFF)   // UTF-16 surrogate pairs
            || cp == 0xFFFD                     // Replacement Character
            || cp == 0xFEFF) {                  // Byte Order Mark (BOM)
            return false;
        }
        switch (cp) {
            case U'/': case U'\\': case U':': case U'*': case U'?':
            case U'"': case U'<':  case U'>': case U'|':
                return false; // Illegal characters
            default:
                break;
        }
    }

    // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
    // Unicode and other whitespace is not affected, only 0x20 space
    // (The trailing '.' rule also rejects the special name ".", so no separate check is needed for it)
    if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
        return false;
    }

    // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
    if (filename.find("..") != std::string::npos) {
        return false;
    }

    return true;
}
+
//
// String utils
//