diff options
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 73 |
1 files changed, 72 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp index 3e2df6e3..7d983a45 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -16,6 +16,7 @@ #include <unordered_set> #include <vector> #include <cinttypes> +#include <codecvt> #if defined(__APPLE__) && defined(__MACH__) #include <sys/types.h> @@ -27,7 +28,6 @@ #ifndef NOMINMAX # define NOMINMAX #endif -#include <codecvt> #include <locale> #include <windows.h> #include <fcntl.h> @@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) { GGML_UNREACHABLE(); } +// Validate if a filename is safe to use +// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function +bool validate_file_name(const std::string & filename) { + if (!filename.length()) { + // Empty filename invalid + return false; + } + if (filename.length() > 255) { + // Limit at common largest possible filename on Linux filesystems + // to avoid unnecessary further validation + // (On systems with smaller limits it will be caught by the OS) + return false; + } + + std::u32string filename_utf32; + try { + std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter; + filename_utf32 = converter.from_bytes(filename); + + // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, + // or invalid encodings were encountered. Reject such attempts + std::string filename_reencoded = converter.to_bytes(filename_utf32); + if (filename_reencoded != filename) { + return false; + } + } catch (const std::exception &) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? " < > | + for (char32_t c : filename_utf32) { + if (c <= 0x1F // Control characters (C0) + || c == 0x7F // Control characters (DEL) + || (c >= 0x80 && c <= 0x9F) // Control characters (C1) + || c == 0xFF0E // Fullwidth Full Stop (period equivalent) + || c == 0x2215 // Division Slash (forward slash equivalent) + || c == 0x2216 // Set Minus (backslash equivalent) + || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c == 0xFFFD // Replacement Character (UTF-8) + || c == 0xFEFF // Byte Order Mark (BOM) + || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters + || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') { + return false; + } + } + + // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename + // Unicode and other whitespace is not affected, only 0x20 space + if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') { + return false; + } + + // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead) + if (filename.find("..") != std::string::npos) { + return false; + } + + // Reject "." + if (filename == ".") { + return false; + } + + return true; +} + // // String utils // |