summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-08-21 23:07:43 +0300
committerGitHub <noreply@github.com>2023-08-21 23:07:43 +0300
commit6381d4e110bd0ec02843a60bbeb8b6fc37a9ace9 (patch)
tree15f5b726f864ad0913bc8dcf6ea08b90ecc7ada9 /common
parentdadbed99e65252d79f81101a392d0d6497b86caa (diff)
gguf : new file format with flexible meta data (beta) (#2398)
* gguf : first API pass * gguf : read header + meta data * gguf : read tensor info * gguf : initial model loading - not tested * gguf : add gguf_get_tensor_name() * gguf : do not support passing existing ggml_context to gguf_init * gguf : simplify gguf_get_val * gguf : gguf.c is now part of ggml.c * gguf : read / write sample models * gguf : add comments * refactor : reduce code duplication and better API (#2415) * gguf : expose the gguf_type enum through the API for now * gguf : add array support * gguf.py : some code style changes * convert.py : start a new simplified implementation by removing old stuff * convert.py : remove GGML vocab + other obsolete stuff * GGUF : write tensor (#2426) * WIP: Write tensor * GGUF : Support writing tensors in Python * refactor : rm unused import and upd todos * fix : fix errors upd writing example * rm example.gguf * gitignore *.gguf * undo formatting * gguf : add gguf_find_key (#2438) * gguf.cpp : find key example * ggml.h : add gguf_find_key * ggml.c : add gguf_find_key * gguf : fix writing tensors * gguf : do not hardcode tensor names to read * gguf : write sample tensors to read * gguf : add tokenization constants * quick and dirty conversion example * gguf : fix writing gguf arrays * gguf : write tensors one by one and code reuse * gguf : fix writing gguf arrays * gguf : write tensors one by one * gguf : write tensors one by one * gguf : write tokenizer data * gguf : upd gguf conversion script * Update convert-llama-h5-to-gguf.py * gguf : handle already encoded string * ggml.h : get array str and f32 * ggml.c : get arr str and f32 * gguf.py : support any type * Update convert-llama-h5-to-gguf.py * gguf : fix set is not subscriptable * gguf : update convert-llama-h5-to-gguf.py * constants.py : add layer norm eps * gguf.py : add layer norm eps and merges * ggml.h : increase GGML_MAX_NAME to 64 * ggml.c : add gguf_get_arr_n * Update convert-llama-h5-to-gguf.py * add gptneox gguf example * Makefile : add gptneox gguf example * Update convert-llama-h5-to-gguf.py * add gptneox gguf example * Update convert-llama-h5-to-gguf.py * Update convert-gptneox-h5-to-gguf.py * Update convert-gptneox-h5-to-gguf.py * Update convert-llama-h5-to-gguf.py * gguf : support custom alignment value * gguf : fix typo in function call * gguf : mmap tensor data example * fix : update convert-llama-h5-to-gguf.py * Update convert-llama-h5-to-gguf.py * convert-gptneox-h5-to-gguf.py : Special tokens * gptneox-main.cpp : special tokens * Update gptneox-main.cpp * constants.py : special tokens * gguf.py : accumulate kv and tensor info data + special tokens * convert-gptneox-h5-to-gguf.py : accumulate kv and ti + special tokens * gguf : gguf counterpart of llama-util.h * gguf-util.h : update note * convert-llama-h5-to-gguf.py : accumulate kv / ti + special tokens * convert-llama-h5-to-gguf.py : special tokens * Delete gptneox-common.cpp * Delete gptneox-common.h * convert-gptneox-h5-to-gguf.py : gpt2bpe tokenizer * gptneox-main.cpp : gpt2 bpe tokenizer * gpt2 bpe tokenizer (handles merges and unicode) * Makefile : remove gptneox-common * gguf.py : bytesarray for gpt2bpe tokenizer * cmpnct_gpt2bpe.hpp : comments * gguf.py : use custom alignment if present * gguf : minor stuff * Update gptneox-main.cpp * map tensor names * convert-gptneox-h5-to-gguf.py : map tensor names * convert-llama-h5-to-gguf.py : map tensor names * gptneox-main.cpp : map tensor names * gguf : start implementing libllama in GGUF (WIP) * gguf : start implementing libllama in GGUF (WIP) * rm binary commited by mistake * upd .gitignore * gguf : calculate n_mult * gguf : inference with 7B model working (WIP) * gguf : rm deprecated function * gguf : start implementing gguf_file_saver (WIP) * gguf : start implementing gguf_file_saver (WIP) * gguf : start implementing gguf_file_saver (WIP) * gguf : add gguf_get_kv_type * gguf : add gguf_get_kv_type * gguf : write metadata in gguf_file_saver (WIP) * gguf : write metadata in gguf_file_saver (WIP) * gguf : write metadata in gguf_file_saver * gguf : rm references to old file formats * gguf : shorter name for member variable * gguf : rm redundant method * gguf : get rid of n_mult, read n_ff from file * Update gguf_tensor_map.py * Update gptneox-main.cpp * gguf : rm references to old file magics * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : quantization is working * gguf : roper closing of file * gguf.py : no need to convert tensors twice * convert-gptneox-h5-to-gguf.py : no need to convert tensors twice * convert-llama-h5-to-gguf.py : no need to convert tensors twice * convert-gptneox-h5-to-gguf.py : simplify nbytes * convert-llama-h5-to-gguf.py : simplify nbytes * gptneox-main.cpp : n_layer --> n_block * constants.py : n_layer --> n_block * gguf.py : n_layer --> n_block * convert-gptneox-h5-to-gguf.py : n_layer --> n_block * convert-llama-h5-to-gguf.py : n_layer --> n_block * gptneox-main.cpp : n_layer --> n_block * Update gguf_tensor_map.py * convert-gptneox-h5-to-gguf.py : load model in parts to save memory * convert-llama-h5-to-gguf.py : load model in parts to save memory * convert : write more metadata for LLaMA * convert : rm quantization version * convert-gptneox-h5-to-gguf.py : add file_type key * gptneox-main.cpp : add file_type key * fix conflicts * gguf : add todos and comments * convert-gptneox-h5-to-gguf.py : tensor name map changes * Create gguf_namemap.py : tensor name map changes * Delete gguf_tensor_map.py * gptneox-main.cpp : tensor name map changes * convert-llama-h5-to-gguf.py : fixes * gguf.py : dont add empty strings * simple : minor style changes * gguf : use UNIX line ending * Create convert-llama-7b-pth-to-gguf.py * llama : sync gguf-llama.cpp with latest llama.cpp (#2608) * llama : sync gguf-llama.cpp with latest llama.cpp * minor : indentation + assert * llama : refactor gguf_buffer and gguf_ctx_buffer * llama : minor * gitignore : add gptneox-main * llama : tokenizer fixes (#2549) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * convert : update convert-new.py with tokenizer fixes (#2614) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * llama : sync gguf-llama with llama (#2613) * llama : sync gguf-llama with llama * tests : fix build + warnings (test-tokenizer-1 still fails) * tests : fix wstring_convert * convert : fix layer names * llama : sync gguf-llama.cpp * convert : update HF converter to new tokenizer voodoo magics * llama : update tokenizer style * convert-llama-h5-to-gguf.py : add token types * constants.py : add token types * gguf.py : add token types * convert-llama-7b-pth-to-gguf.py : add token types * gguf-llama.cpp : fix n_head_kv * convert-llama-h5-to-gguf.py : add 70b gqa support * gguf.py : add tensor data layout * convert-llama-h5-to-gguf.py : add tensor data layout * convert-llama-7b-pth-to-gguf.py : add tensor data layout * gptneox-main.cpp : add tensor data layout * convert-llama-h5-to-gguf.py : clarify the reverse permute * llama : refactor model loading code (#2620) * llama : style formatting + remove helper methods * llama : fix quantization using gguf tool * llama : simplify gguf_file_saver * llama : fix method names * llama : simplify write_header() * llama : no need to pass full file loader to the file saver just gguf_ctx * llama : gguf_file_saver write I32 * llama : refactor tensor names (#2622) * gguf: update tensor names searched in quantization * gguf : define tensor names as constants * gguf : initial write API (not tested yet) * gguf : write to file API (not tested) * gguf : initial write API ready + example * gguf : fix header write * gguf : fixes + simplify example + add ggml_nbytes_pad() * gguf : minor * llama : replace gguf_file_saver with new gguf write API * gguf : streaming support when writing files * gguf : remove oboslete write methods * gguf : remove obosolete gguf_get_arr_xxx API * llama : simplify gguf_file_loader * llama : move hparams and vocab from gguf_file_loader to llama_model_loader * llama : merge gguf-util.h in llama.cpp * llama : reorder definitions in .cpp to match .h * llama : minor simplifications * llama : refactor llama_model_loader (WIP) wip : remove ggml_ctx from llama_model_loader wip : merge gguf_file_loader in llama_model_loader * llama : fix shape prints * llama : fix Windows build + fix norm_rms_eps key * llama : throw error on missing KV paris in model meta data * llama : improve printing + log meta data * llama : switch print order of meta data --------- Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com> * gguf : deduplicate (#2629) * gguf : better type names * dedup : CPU + Metal is working * ggml : fix warnings about unused results * llama.cpp : fix line feed and compiler warning * llama : fix strncpy warning + note token_to_str does not write null * llama : restore the original load/save session implementation Will migrate this to GGUF in the future * convert-llama-h5-to-gguf.py : support alt ctx param name * ggml : assert when using ggml_mul with non-F32 src1 * examples : dedup simple --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> * gguf.py : merge all files in gguf.py * convert-new.py : pick #2427 for HF 70B support * examples/gguf : no need to keep q option for quantization any more * llama.cpp : print actual model size * llama.cpp : use ggml_elements() * convert-new.py : output gguf (#2635) * convert-new.py : output gguf (WIP) * convert-new.py : add gguf key-value pairs * llama : add hparams.ctx_train + no longer print ftype * convert-new.py : minor fixes * convert-new.py : vocab-only option should work now * llama : fix tokenizer to use llama_char_to_byte * tests : add new ggml-vocab-llama.gguf * convert-new.py : tensor name mapping * convert-new.py : add map for skipping tensor serialization * convert-new.py : convert script now works * gguf.py : pick some of the refactoring from #2644 * convert-new.py : minor fixes * convert.py : update to support GGUF output * Revert "ci : disable CI temporary to not waste energy" This reverts commit 7e82d25f40386540c2c15226300ad998ecd871ea. * convert.py : n_head_kv optional and .gguf file extension * convert.py : better always have n_head_kv and default it to n_head * llama : sync with recent PRs on master * editorconfig : ignore models folder ggml-ci * ci : update ".bin" to ".gguf" extension ggml-ci * llama : fix llama_model_loader memory leak * gptneox : move as a WIP example * llama : fix lambda capture ggml-ci * ggml : fix bug in gguf_set_kv ggml-ci * common.h : .bin --> .gguf * quantize-stats.cpp : .bin --> .gguf * convert.py : fix HF tensor permuting / unpacking ggml-ci * llama.cpp : typo * llama : throw error if gguf fails to init from file ggml-ci * llama : fix tensor name grepping during quantization ggml-ci * gguf.py : write tensors in a single pass (#2644) * gguf : single pass for writing tensors + refactoring writer * gguf : single pass for writing tensors + refactoring writer * gguf : single pass for writing tensors + refactoring writer * gguf : style fixes in simple conversion script * gguf : refactor gptneox conversion script * gguf : rename h5 to hf (for HuggingFace) * gguf : refactor pth to gguf conversion script * gguf : rm file_type key and method * gguf.py : fix vertical alignment * gguf.py : indentation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * convert-gptneox-hf-to-gguf.py : fixes * gguf.py : gptneox mapping * convert-llama-hf-to-gguf.py : fixes * convert-llama-7b-pth-to-gguf.py : fixes * ggml.h : reverse GGUF_MAGIC * gguf.py : reverse GGUF_MAGIC * test-tokenizer-0.cpp : fix warning * llama.cpp : print kv general.name * llama.cpp : get special token kv and linefeed token id * llama : print number of tensors per type + print arch + style * tests : update vocab file with new magic * editorconfig : fix whitespaces * llama : re-order functions * llama : remove C++ API + reorganize common source in /common dir * llama : minor API updates * llama : avoid hardcoded special tokens * llama : fix MPI build ggml-ci * llama : introduce enum llama_vocab_type + remove hardcoded string constants * convert-falcon-hf-to-gguf.py : falcon HF --> gguf conversion, not tested * falcon-main.cpp : falcon inference example * convert-falcon-hf-to-gguf.py : remove extra kv * convert-gptneox-hf-to-gguf.py : remove extra kv * convert-llama-7b-pth-to-gguf.py : remove extra kv * convert-llama-hf-to-gguf.py : remove extra kv * gguf.py : fix for falcon 40b * falcon-main.cpp : fix for falcon 40b * convert-falcon-hf-to-gguf.py : update ref * convert-falcon-hf-to-gguf.py : add tensor data layout * cmpnct_gpt2bpe.hpp : fixes * falcon-main.cpp : fixes * gptneox-main.cpp : fixes * cmpnct_gpt2bpe.hpp : remove non-general stuff * Update examples/server/README.md Co-authored-by: slaren <slarengh@gmail.com> * cmpnct_gpt2bpe.hpp : cleanup * convert-llama-hf-to-gguf.py : special tokens * convert-llama-7b-pth-to-gguf.py : special tokens * convert-permute-debug.py : permute debug print * convert-permute-debug-master.py : permute debug for master * convert-permute-debug.py : change permute type of attn_q * convert.py : 70b model working (change attn_q permute) * Delete convert-permute-debug-master.py * Delete convert-permute-debug.py * convert-llama-hf-to-gguf.py : fix attn_q permute * gguf.py : fix rope scale kv * convert-llama-hf-to-gguf.py : rope scale and added tokens * convert-llama-7b-pth-to-gguf.py : rope scale and added tokens * llama.cpp : use rope scale kv * convert-llama-7b-pth-to-gguf.py : rope scale fix * convert-llama-hf-to-gguf.py : rope scale fix * py : fix whitespace * gguf : add Python script to convert GGMLv3 LLaMA models to GGUF (#2682) * First pass at converting GGMLv3 LLaMA models to GGUF * Cleanups, better output during conversion * Fix vocab space conversion logic * More vocab conversion fixes * Add description to converted GGUF files * Improve help text, expand warning * Allow specifying name and description for output GGUF * Allow overriding vocab and hyperparams from original model metadata * Use correct params override var name * Fix wrong type size for Q8_K Better handling of original style metadata * Set default value for gguf add_tensor raw_shape KW arg * llama : improve token type support (#2668) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * Improved tokenizer test But does it work on MacOS? * Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary * Exclude platform dependent tests * More sentencepiece compatibility by eliminating magic numbers * Restored accidentally removed comment * llama : add API for token type ggml-ci * tests : use new tokenizer type API (#2692) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * Improved tokenizer test But does it work on MacOS? * Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary * Exclude platform dependent tests * More sentencepiece compatibility by eliminating magic numbers * Restored accidentally removed comment * Improve commentary * Use token type API in test-tokenizer-1.cpp * py : cosmetics * readme : add notice about new file format ggml-ci --------- Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com> Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: goerch <jhr.walter@t-online.de> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
Diffstat (limited to 'common')
-rw-r--r--common/CMakeLists.txt20
-rw-r--r--common/common.cpp767
-rw-r--r--common/common.h130
-rw-r--r--common/console.cpp500
-rw-r--r--common/console.h19
-rw-r--r--common/grammar-parser.cpp423
-rw-r--r--common/grammar-parser.h29
7 files changed, 1888 insertions, 0 deletions
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
new file mode 100644
index 00000000..dead5611
--- /dev/null
+++ b/common/CMakeLists.txt
@@ -0,0 +1,20 @@
+# common
+
+set(TARGET common)
+
+add_library(${TARGET} OBJECT
+ common.h
+ common.cpp
+ console.h
+ console.cpp
+ grammar-parser.h
+ grammar-parser.cpp
+ )
+
+if (BUILD_SHARED_LIBS)
+ set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+endif()
+
+target_include_directories(${TARGET} PUBLIC .)
+target_compile_features(${TARGET} PUBLIC cxx_std_11)
+target_link_libraries(${TARGET} PRIVATE llama)
diff --git a/common/common.cpp b/common/common.cpp
new file mode 100644
index 00000000..d7e1a572
--- /dev/null
+++ b/common/common.cpp
@@ -0,0 +1,767 @@
+#include "common.h"
+
+#include <cassert>
+#include <iostream>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <iterator>
+#include <algorithm>
+#include <sstream>
+#include <unordered_set>
+#include <regex>
+
+#if defined(__APPLE__) && defined(__MACH__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#define NOMINMAX
+#include <windows.h>
+#include <fcntl.h>
+#include <io.h>
+#else
+#include <sys/ioctl.h>
+#include <unistd.h>
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+int32_t get_num_physical_cores() {
+#ifdef __linux__
+ // enumerate the set of thread siblings, num entries is num cores
+ std::unordered_set<std::string> siblings;
+ for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
+ std::ifstream thread_siblings("/sys/devices/system/cpu"
+ + std::to_string(cpu) + "/topology/thread_siblings");
+ if (!thread_siblings.is_open()) {
+ break; // no more cpus
+ }
+ std::string line;
+ if (std::getline(thread_siblings, line)) {
+ siblings.insert(line);
+ }
+ }
+ if (siblings.size() > 0) {
+ return static_cast<int32_t>(siblings.size());
+ }
+#elif defined(__APPLE__) && defined(__MACH__)
+ int32_t num_physical_cores;
+ size_t len = sizeof(num_physical_cores);
+ int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
+ if (result == 0) {
+ return num_physical_cores;
+ }
+ result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
+ if (result == 0) {
+ return num_physical_cores;
+ }
+#elif defined(_WIN32)
+ //TODO: Implement
+#endif
+ unsigned int n_threads = std::thread::hardware_concurrency();
+ return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
+}
+
+void process_escapes(std::string& input) {
+ std::size_t input_len = input.length();
+ std::size_t output_idx = 0;
+
+ for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
+ if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
+ switch (input[++input_idx]) {
+ case 'n': input[output_idx++] = '\n'; break;
+ case 'r': input[output_idx++] = '\r'; break;
+ case 't': input[output_idx++] = '\t'; break;
+ case '\'': input[output_idx++] = '\''; break;
+ case '\"': input[output_idx++] = '\"'; break;
+ case '\\': input[output_idx++] = '\\'; break;
+ default: input[output_idx++] = '\\';
+ input[output_idx++] = input[input_idx]; break;
+ }
+ } else {
+ input[output_idx++] = input[input_idx];
+ }
+ }
+
+ input.resize(output_idx);
+}
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+ bool invalid_param = false;
+ bool escape_prompt = false;
+ std::string arg;
+ gpt_params default_params;
+ const std::string arg_prefix = "--";
+
+ for (int i = 1; i < argc; i++) {
+ arg = argv[i];
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+ std::replace(arg.begin(), arg.end(), '_', '-');
+ }
+
+ if (arg == "-s" || arg == "--seed") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.seed = std::stoul(argv[i]);
+ } else if (arg == "-t" || arg == "--threads") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads = std::stoi(argv[i]);
+ if (params.n_threads <= 0) {
+ params.n_threads = std::thread::hardware_concurrency();
+ }
+ } else if (arg == "-p" || arg == "--prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.prompt = argv[i];
+ } else if (arg == "-e") {
+ escape_prompt = true;
+ } else if (arg == "--prompt-cache") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.path_prompt_cache = argv[i];
+ } else if (arg == "--prompt-cache-all") {
+ params.prompt_cache_all = true;
+ } else if (arg == "--prompt-cache-ro") {
+ params.prompt_cache_ro = true;
+ } else if (arg == "-f" || arg == "--file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+ if (params.prompt.back() == '\n') {
+ params.prompt.pop_back();
+ }
+ } else if (arg == "-n" || arg == "--n-predict") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_predict = std::stoi(argv[i]);
+ } else if (arg == "--top-k") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.top_k = std::stoi(argv[i]);
+ } else if (arg == "-c" || arg == "--ctx-size") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_ctx = std::stoi(argv[i]);
+ } else if (arg == "--rope-freq-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_base = std::stof(argv[i]);
+ } else if (arg == "--rope-freq-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = std::stof(argv[i]);
+ } else if (arg == "--rope-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+ } else if (arg == "--memory-f32") {
+ params.memory_f16 = false;
+ } else if (arg == "--top-p") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.top_p = std::stof(argv[i]);
+ } else if (arg == "--temp") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.temp = std::stof(argv[i]);
+ } else if (arg == "--tfs") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.tfs_z = std::stof(argv[i]);
+ } else if (arg == "--typical") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.typical_p = std::stof(argv[i]);
+ } else if (arg == "--repeat-last-n") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.repeat_last_n = std::stoi(argv[i]);
+ } else if (arg == "--repeat-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.repeat_penalty = std::stof(argv[i]);
+ } else if (arg == "--frequency-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.frequency_penalty = std::stof(argv[i]);
+ } else if (arg == "--presence-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.presence_penalty = std::stof(argv[i]);
+ } else if (arg == "--mirostat") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat = std::stoi(argv[i]);
+ } else if (arg == "--mirostat-lr") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_eta = std::stof(argv[i]);
+ } else if (arg == "--mirostat-ent") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mirostat_tau = std::stof(argv[i]);
+ } else if (arg == "--cfg-negative-prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.cfg_negative_prompt = argv[i];
+ } else if (arg == "--cfg-negative-prompt-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
+ if (params.cfg_negative_prompt.back() == '\n') {
+ params.cfg_negative_prompt.pop_back();
+ }
+ } else if (arg == "--cfg-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.cfg_scale = std::stof(argv[i]);
+ } else if (arg == "-b" || arg == "--batch-size") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_batch = std::stoi(argv[i]);
+ params.n_batch = std::min(512, params.n_batch);
+ } else if (arg == "--keep") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_keep = std::stoi(argv[i]);
+ } else if (arg == "--chunks") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_chunks = std::stoi(argv[i]);
+ } else if (arg == "-m" || arg == "--model") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model = argv[i];
+ } else if (arg == "-a" || arg == "--alias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model_alias = argv[i];
+ } else if (arg == "--lora") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_adapter = argv[i];
+ params.use_mmap = false;
+ } else if (arg == "--lora-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_base = argv[i];
+ } else if (arg == "-i" || arg == "--interactive") {
+ params.interactive = true;
+ } else if (arg == "--embedding") {
+ params.embedding = true;
+ } else if (arg == "--interactive-first") {
+ params.interactive_first = true;
+ } else if (arg == "-ins" || arg == "--instruct") {
+ params.instruct = true;
+ } else if (arg == "--multiline-input") {
+ params.multiline_input = true;
+ } else if (arg == "--simple-io") {
+ params.simple_io = true;
+ } else if (arg == "--color") {
+ params.use_color = true;
+ } else if (arg == "--mlock") {
+ params.use_mlock = true;
+ } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+ params.n_gpu_layers = std::stoi(argv[i]);
+#else
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
+ } else if (arg == "--main-gpu" || arg == "-mg") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef GGML_USE_CUBLAS
+ params.main_gpu = std::stoi(argv[i]);
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
+#endif
+ } else if (arg == "--tensor-split" || arg == "-ts") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef GGML_USE_CUBLAS
+ std::string arg_next = argv[i];
+
+ // split string by , and /
+ const std::regex regex{R"([,/]+)"};
+ std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+ std::vector<std::string> split_arg{it, {}};
+ GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+ if (i < split_arg.size()) {
+ params.tensor_split[i] = std::stof(split_arg[i]);
+ } else {
+ params.tensor_split[i] = 0.0f;
+ }
+ }
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
+ } else if (arg == "--mul-mat-q" || arg == "-mmq") {
+#ifdef GGML_USE_CUBLAS
+ params.mul_mat_q = true;
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
+#endif // GGML_USE_CUBLAS
+ } else if (arg == "--low-vram" || arg == "-lv") {
+#ifdef GGML_USE_CUBLAS
+ params.low_vram = true;
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
+#endif // GGML_USE_CUBLAS
+ } else if (arg == "--no-mmap") {
+ params.use_mmap = false;
+ } else if (arg == "--mtest") {
+ params.mem_test = true;
+ } else if (arg == "--numa") {
+ params.numa = true;
+ } else if (arg == "--export") {
+ params.export_cgraph = true;
+ } else if (arg == "--verbose-prompt") {
+ params.verbose_prompt = true;
+ } else if (arg == "-r" || arg == "--reverse-prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.antiprompt.push_back(argv[i]);
+ } else if (arg == "--perplexity") {
+ params.perplexity = true;
+ } else if (arg == "--hellaswag") {
+ params.hellaswag = true;
+ } else if (arg == "--hellaswag-tasks") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.hellaswag_tasks = std::stoi(argv[i]);
+ } else if (arg == "--ignore-eos") {
+ params.ignore_eos = true;
+ } else if (arg == "--no-penalize-nl") {
+ params.penalize_nl = false;
+ } else if (arg == "-l" || arg == "--logit-bias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::stringstream ss(argv[i]);
+ llama_token key;
+ char sign;
+ std::string value_str;
+ try {
+ if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+ params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ } else {
+ throw std::exception();
+ }
+ } catch (const std::exception&) {
+ invalid_param = true;
+ break;
+ }
+ } else if (arg == "-h" || arg == "--help") {
+ gpt_print_usage(argc, argv, default_params);
+ exit(0);
+ } else if (arg == "--random-prompt") {
+ params.random_prompt = true;
+ } else if (arg == "--in-prefix-bos") {
+ params.input_prefix_bos = true;
+ } else if (arg == "--in-prefix") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.input_prefix = argv[i];
+ } else if (arg == "--in-suffix") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.input_suffix = argv[i];
+ } else if (arg == "--grammar") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.grammar = argv[i];
+ } else if (arg == "--grammar-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(
+ std::istreambuf_iterator<char>(file),
+ std::istreambuf_iterator<char>(),
+ std::back_inserter(params.grammar)
+ );
+ } else {
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+ gpt_print_usage(argc, argv, default_params);
+ exit(1);
+ }
+ }
+ if (invalid_param) {
+ fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+ gpt_print_usage(argc, argv, default_params);
+ exit(1);
+ }
+ if (params.prompt_cache_all &&
+ (params.interactive || params.interactive_first ||
+ params.instruct)) {
+ fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
+ gpt_print_usage(argc, argv, default_params);
+ exit(1);
+ }
+
+ if (escape_prompt) {
+ process_escapes(params.prompt);
+ process_escapes(params.input_prefix);
+ process_escapes(params.input_suffix);
+ }
+
+ return true;
+}
+
+void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
+ fprintf(stdout, "usage: %s [options]\n", argv[0]);
+ fprintf(stdout, "\n");
+ fprintf(stdout, "options:\n");
+ fprintf(stdout, " -h, --help show this help message and exit\n");
+ fprintf(stdout, " -i, --interactive run in interactive mode\n");
+ fprintf(stdout, " --interactive-first run in interactive mode and wait for input right away\n");
+ fprintf(stdout, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
+ fprintf(stdout, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
+ fprintf(stdout, " -r PROMPT, --reverse-prompt PROMPT\n");
+ fprintf(stdout, " halt generation at PROMPT, return control in interactive mode\n");
+ fprintf(stdout, " (can be specified more than once for multiple prompts).\n");
+ fprintf(stdout, " --color colorise output to distinguish prompt and user input from generations\n");
+ fprintf(stdout, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
+ fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+ fprintf(stdout, " -p PROMPT, --prompt PROMPT\n");
+ fprintf(stdout, " prompt to start generation with (default: empty)\n");
+ fprintf(stdout, " -e process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+ fprintf(stdout, " --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n");
+ fprintf(stdout, " --prompt-cache-all if specified, saves user input and generations to cache as well.\n");
+ fprintf(stdout, " not supported with --interactive or other interactive options\n");
+ fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
+ fprintf(stdout, " --random-prompt start with a randomized prompt.\n");
+ fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n");
+ fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
+ fprintf(stdout, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
+ fprintf(stdout, " -f FNAME, --file FNAME\n");
+ fprintf(stdout, " prompt file to start generation.\n");
+ fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
+ fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
+ fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
+ fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
+ fprintf(stdout, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
+ fprintf(stdout, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+ fprintf(stdout, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+ fprintf(stdout, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
+ fprintf(stdout, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+ fprintf(stdout, " --mirostat N use Mirostat sampling.\n");
+ fprintf(stdout, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+ fprintf(stdout, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
+ fprintf(stdout, " --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
+ fprintf(stdout, " --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+ fprintf(stdout, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+ fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
+ fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+ fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+ fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
+ fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
+ fprintf(stdout, " --cfg-negative-prompt PROMPT\n");
+ fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
+ fprintf(stdout, " --cfg-negative-prompt-file FNAME\n");
+ fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n");
+ fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+ fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
+ fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
+ fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
+ fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+ fprintf(stdout, " --no-penalize-nl do not penalize newline token\n");
+ fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
+ fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
+ fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp);
+ fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n");
+ fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
+ fprintf(stdout, " --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
+ fprintf(stdout, " --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+ fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+ if (llama_mlock_supported()) {
+ fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
+ }
+ if (llama_mmap_supported()) {
+ fprintf(stdout, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+ }
+ fprintf(stdout, " --numa attempt optimizations that help on some NUMA systems\n");
+ fprintf(stdout, " if run without this previously, it is recommended to drop the system page cache before using this\n");
+ fprintf(stdout, " see https://github.com/ggerganov/llama.cpp/issues/1437\n");
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+ fprintf(stdout, " -ngl N, --n-gpu-layers N\n");
+ fprintf(stdout, " number of layers to store in VRAM\n");
+ fprintf(stdout, " -ts SPLIT --tensor-split SPLIT\n");
+ fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+ fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
+ fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
+ fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
+ fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
+ fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
+#endif
+ fprintf(stdout, " --mtest compute maximum memory usage\n");
+ fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");
+ fprintf(stdout, " --verbose-prompt print prompt before generation\n");
+ fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
+ fprintf(stdout, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+ fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
+ fprintf(stdout, " -m FNAME, --model FNAME\n");
+ fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
+ fprintf(stdout, "\n");
+}
+
+std::string gpt_random_prompt(std::mt19937 & rng) {
+ const int r = rng() % 10;
+ switch (r) {
+ case 0: return "So";
+ case 1: return "Once upon a time";
+ case 2: return "When";
+ case 3: return "The";
+ case 4: return "After";
+ case 5: return "If";
+ case 6: return "import";
+ case 7: return "He";
+ case 8: return "She";
+ case 9: return "They";
+ default: return "To";
+ }
+
+ return "The";
+}
+
+//
+// Model utils
+//
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+ auto lparams = llama_context_default_params();
+
+ lparams.n_ctx = params.n_ctx;
+ lparams.n_batch = params.n_batch;
+ lparams.n_gpu_layers = params.n_gpu_layers;
+ lparams.main_gpu = params.main_gpu;
+ lparams.tensor_split = params.tensor_split;
+ lparams.low_vram = params.low_vram;
+ lparams.mul_mat_q = params.mul_mat_q;
+ lparams.seed = params.seed;
+ lparams.f16_kv = params.memory_f16;
+ lparams.use_mmap = params.use_mmap;
+ lparams.use_mlock = params.use_mlock;
+ lparams.logits_all = params.perplexity;
+ lparams.embedding = params.embedding;
+ lparams.rope_freq_base = params.rope_freq_base;
+ lparams.rope_freq_scale = params.rope_freq_scale;
+
+ return lparams;
+}
+
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
+ auto lparams = llama_context_params_from_gpt_params(params);
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+ if (model == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+ return std::make_tuple(nullptr, nullptr);
+ }
+
+ llama_context * lctx = llama_new_context_with_model(model, lparams);
+ if (lctx == NULL) {
+ fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+
+ if (!params.lora_adapter.empty()) {
+ int err = llama_model_apply_lora_from_file(model,
+ params.lora_adapter.c_str(),
+ params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ }
+
+ if (params.ignore_eos) {
+ params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+ }
+
+ return std::make_tuple(model, lctx);
+}
+
+//
+// Vocab utils
+//
+
+std::vector<llama_token> llama_tokenize(
+ struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos) {
+ // upper limit for the number of tokens
+ int n_tokens = text.length() + add_bos;
+ std::vector<llama_token> result(n_tokens);
+ n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+ return result;
+}
+
+std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+ std::vector<char> result(8, 0);
+ const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_token_to_str(ctx, token, result.data(), result.size());
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+
+ return std::string(result.data(), result.size());
+}
+
+std::vector<llama_token> llama_tokenize_bpe(
+ struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos) {
+ int n_tokens = text.length() + add_bos;
+ std::vector<llama_token> result(n_tokens);
+ n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+ return result;
+}
+
+std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
+ std::vector<char> result(8, 0);
+ const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+
+ return std::string(result.data(), result.size());
+}
+
diff --git a/common/common.h b/common/common.h
new file mode 100644
index 00000000..c50a6edf
--- /dev/null
+++ b/common/common.h
@@ -0,0 +1,130 @@
+// Various helper functions and utilities
+
+#pragma once
+
+#include "llama.h"
+
+#include <string>
+#include <vector>
+#include <random>
+#include <thread>
+#include <unordered_map>
+#include <tuple>
+
+//
+// CLI argument parsing
+//
+int32_t get_num_physical_cores();
+
+struct gpt_params {
+ uint32_t seed = -1; // RNG seed
+ int32_t n_threads = get_num_physical_cores();
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+ int32_t n_gpu_layers = 0; // number of layers to store in VRAM
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+ float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ float rope_freq_base = 10000.0f; // RoPE base frequency
+ float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
+
+ // sampling parameters
+ int32_t top_k = 40; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typical_p = 1.00f; // 1.0 = disabled
+ float temp = 0.80f; // 1.0 = disabled
+ float repeat_penalty = 1.10f; // 1.0 = disabled
+ int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float frequency_penalty = 0.00f; // 0.0 = disabled
+ float presence_penalty = 0.00f; // 0.0 = disabled
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+
+ std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+ // Classifier-Free Guidance
+ // https://arxiv.org/abs/2306.17806
+ std::string cfg_negative_prompt; // string to help guidance
+ float cfg_scale = 1.f; // How strong is guidance
+
+ std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+ std::string model_alias = "unknown"; // model alias
+ std::string prompt = "";
+ std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+ std::string input_prefix = ""; // string to prefix user inputs with
+ std::string input_suffix = ""; // string to suffix user inputs with
+ std::string grammar = ""; // optional BNF-like grammar to constrain sampling
+ std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+
+ std::string lora_adapter = ""; // lora adapter path
+ std::string lora_base = ""; // base model path for the lora adapter
+
+ bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+ size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
+
+ bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
+ bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels
+ bool memory_f16 = true; // use f16 instead of f32 for memory kv
+ bool random_prompt = false; // do not randomize prompt if none provided
+ bool use_color = false; // use color to distinguish generations and inputs
+ bool interactive = false; // interactive mode
+ bool prompt_cache_all = false; // save user input and generations to prompt cache
+ bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
+
+ bool embedding = false; // get only sentence embedding
+ bool interactive_first = false; // wait for user input immediately
+ bool multiline_input = false; // reverse the usage of `\`
+ bool simple_io = false; // improves compatibility with subprocesses and limited consoles
+
+ bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
+ bool ignore_eos = false; // ignore generated EOS tokens
+ bool instruct = false; // instruction mode (used for Alpaca models)
+ bool penalize_nl = true; // consider newlines as a repeatable token
+ bool perplexity = false; // compute perplexity over the prompt
+ bool use_mmap = true; // use mmap for faster loads
+ bool use_mlock = false; // use mlock to keep model in memory
+ bool mem_test = false; // compute maximum memory usage
+ bool numa = false; // attempt optimizations that help on some NUMA systems
+ bool export_cgraph = false; // export the computation graph
+ bool verbose_prompt = false; // print prompt tokens before generation
+};
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+
+std::string gpt_random_prompt(std::mt19937 & rng);
+
+//
+// Model utils
+//
+
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
+
+//
+// Vocab utils
+//
+
+std::vector<llama_token> llama_tokenize(
+ struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos);
+
+std::vector<llama_token> llama_tokenize_bpe(
+ struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos);
+
+std::string llama_token_to_str(
+ const struct llama_context * ctx,
+ llama_token token);
+
+std::string llama_token_to_str_bpe(
+ const struct llama_context * ctx,
+ llama_token token);
diff --git a/common/console.cpp b/common/console.cpp
new file mode 100644
index 00000000..8efa2a67
--- /dev/null
+++ b/common/console.cpp
@@ -0,0 +1,500 @@
+#include "console.h"
+#include <vector>
+#include <iostream>
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <fcntl.h>
+#include <io.h>
+#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
+#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
+#endif
+#else
+#include <climits>
+#include <sys/ioctl.h>
+#include <unistd.h>
+#include <wchar.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <signal.h>
+#include <termios.h>
+#endif
+
+#define ANSI_COLOR_RED "\x1b[31m"
+#define ANSI_COLOR_GREEN "\x1b[32m"
+#define ANSI_COLOR_YELLOW "\x1b[33m"
+#define ANSI_COLOR_BLUE "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN "\x1b[36m"
+#define ANSI_COLOR_RESET "\x1b[0m"
+#define ANSI_BOLD "\x1b[1m"
+
+namespace console {
+
+ //
+ // Console state
+ //
+
+ static bool advanced_display = false;
+ static bool simple_io = true;
+ static display_t current_display = reset;
+
+ static FILE* out = stdout;
+
+#if defined (_WIN32)
+ static void* hConsole;
+#else
+ static FILE* tty = nullptr;
+ static termios initial_state;
+#endif
+
+ //
+ // Init and cleanup
+ //
+
+ void init(bool use_simple_io, bool use_advanced_display) {
+ advanced_display = use_advanced_display;
+ simple_io = use_simple_io;
+#if defined(_WIN32)
+ // Windows-specific console initialization
+ DWORD dwMode = 0;
+ hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
+ if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
+ hConsole = GetStdHandle(STD_ERROR_HANDLE);
+ if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
+ hConsole = nullptr;
+ simple_io = true;
+ }
+ }
+ if (hConsole) {
+ // Check conditions combined to reduce nesting
+ if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
+ !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
+ advanced_display = false;
+ }
+ // Set console output codepage to UTF8
+ SetConsoleOutputCP(CP_UTF8);
+ }
+ HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
+ if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
+ // Set console input codepage to UTF16
+ _setmode(_fileno(stdin), _O_WTEXT);
+
+ // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
+ if (simple_io) {
+ dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
+ } else {
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
+ }
+ if (!SetConsoleMode(hConIn, dwMode)) {
+ simple_io = true;
+ }
+ }
+#else
+ // POSIX-specific console initialization
+ if (!simple_io) {
+ struct termios new_termios;
+ tcgetattr(STDIN_FILENO, &initial_state);
+ new_termios = initial_state;
+ new_termios.c_lflag &= ~(ICANON | ECHO);
+ new_termios.c_cc[VMIN] = 1;
+ new_termios.c_cc[VTIME] = 0;
+ tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
+
+ tty = fopen("/dev/tty", "w+");
+ if (tty != nullptr) {
+ out = tty;
+ }
+ }
+
+ setlocale(LC_ALL, "");
+#endif
+ }
+
+ void cleanup() {
+ // Reset console display
+ set_display(reset);
+
+#if !defined(_WIN32)
+ // Restore settings on POSIX systems
+ if (!simple_io) {
+ if (tty != nullptr) {
+ out = stdout;
+ fclose(tty);
+ tty = nullptr;
+ }
+ tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
+ }
+#endif
+ }
+
+ //
+ // Display and IO
+ //
+
+ // Keep track of current display and only emit ANSI code if it changes
+ void set_display(display_t display) {
+ if (advanced_display && current_display != display) {
+ fflush(stdout);
+ switch(display) {
+ case reset:
+ fprintf(out, ANSI_COLOR_RESET);
+ break;
+ case prompt:
+ fprintf(out, ANSI_COLOR_YELLOW);
+ break;
+ case user_input:
+ fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
+ break;
+ case error:
+ fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
+ }
+ current_display = display;
+ fflush(out);
+ }
+ }
+
+ char32_t getchar32() {
+#if defined(_WIN32)
+ HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
+ wchar_t high_surrogate = 0;
+
+ while (true) {
+ INPUT_RECORD record;
+ DWORD count;
+ if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
+ return WEOF;
+ }
+
+ if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
+ wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
+ if (wc == 0) {
+ continue;
+ }
+
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
+ high_surrogate = wc;
+ continue;
+ }
+ if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
+ if (high_surrogate != 0) { // Check if we have a high surrogate
+ return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
+ }
+ }
+
+ high_surrogate = 0; // Reset the high surrogate
+ return static_cast<char32_t>(wc);
+ }
+ }
+#else
+ wchar_t wc = getwchar();
+ if (static_cast<wint_t>(wc) == WEOF) {
+ return WEOF;
+ }
+
+#if WCHAR_MAX == 0xFFFF
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
+ wchar_t low_surrogate = getwchar();
+ if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
+ return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
+ }
+ }
+ if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
+ return 0xFFFD; // Return the replacement character U+FFFD
+ }
+#endif
+
+ return static_cast<char32_t>(wc);
+#endif
+ }
+
+ void pop_cursor() {
+#if defined(_WIN32)
+ if (hConsole != NULL) {
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
+ GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
+
+ COORD newCursorPosition = bufferInfo.dwCursorPosition;
+ if (newCursorPosition.X == 0) {
+ newCursorPosition.X = bufferInfo.dwSize.X - 1;
+ newCursorPosition.Y -= 1;
+ } else {
+ newCursorPosition.X -= 1;
+ }
+
+ SetConsoleCursorPosition(hConsole, newCursorPosition);
+ return;
+ }
+#endif
+ putc('\b', out);
+ }
+
+ int estimateWidth(char32_t codepoint) {
+#if defined(_WIN32)
+ return 1;
+#else
+ return wcwidth(codepoint);
+#endif
+ }
+
+ int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
+#if defined(_WIN32)
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
+ if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
+ // go with the default
+ return expectedWidth;
+ }
+ COORD initialPosition = bufferInfo.dwCursorPosition;
+ DWORD nNumberOfChars = length;
+ WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
+
+ CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
+
+ // Figure out our real position if we're in the last column
+ if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
+ DWORD nNumberOfChars;
+ WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL);
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
+ }
+
+ int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
+ if (width < 0) {
+ width += newBufferInfo.dwSize.X;
+ }
+ return width;
+#else
+ // We can trust expectedWidth if we've got one
+ if (expectedWidth >= 0 || tty == nullptr) {
+ fwrite(utf8_codepoint, length, 1, out);
+ return expectedWidth;
+ }
+
+ fputs("\033[6n", tty); // Query cursor position
+ int x1;
+ int y1;
+ int x2;
+ int y2;
+ int results = 0;
+ results = fscanf(tty, "\033[%d;%dR", &y1, &x1);
+
+ fwrite(utf8_codepoint, length, 1, tty);
+
+ fputs("\033[6n", tty); // Query cursor position
+ results += fscanf(tty, "\033[%d;%dR", &y2, &x2);
+
+ if (results != 4) {
+ return expectedWidth;
+ }
+
+ int width = x2 - x1;
+ if (width < 0) {
+ // Calculate the width considering text wrapping
+ struct winsize w;
+ ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
+ width += w.ws_col;
+ }
+ return width;
+#endif
+ }
+
+ void replace_last(char ch) {
+#if defined(_WIN32)
+ pop_cursor();
+ put_codepoint(&ch, 1, 1);
+#else
+ fprintf(out, "\b%c", ch);
+#endif
+ }
+
+ void append_utf8(char32_t ch, std::string & out) {
+ if (ch <= 0x7F) {
+ out.push_back(static_cast<unsigned char>(ch));
+ } else if (ch <= 0x7FF) {
+ out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
+ } else if (ch <= 0xFFFF) {
+ out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
+ } else if (ch <= 0x10FFFF) {
+ out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
+ } else {
+ // Invalid Unicode code point
+ }
+ }
+
+ // Helper function to remove the last UTF-8 character from a string
+ void pop_back_utf8_char(std::string & line) {
+ if (line.empty()) {
+ return;
+ }
+
+ size_t pos = line.length() - 1;
+
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
+ for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
+ if ((line[pos] & 0xC0) != 0x80) {
+ break; // Found the start of the character
+ }
+ }
+ line.erase(pos);
+ }
+
+ bool readline_advanced(std::string & line, bool multiline_input) {
+ if (out != stdout) {
+ fflush(stdout);
+ }
+
+ line.clear();
+ std::vector<int> widths;
+ bool is_special_char = false;
+ bool end_of_stream = false;
+
+ char32_t input_char;
+ while (true) {
+ fflush(out); // Ensure all output is displayed before waiting for input
+ input_char = getchar32();
+
+ if (input_char == '\r' || input_char == '\n') {
+ break;
+ }
+
+ if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
+ end_of_stream = true;
+ break;
+ }
+
+ if (is_special_char) {
+ set_display(user_input);
+ replace_last(line.back());
+ is_special_char = false;
+ }
+
+ if (input_char == '\033') { // Escape sequence
+ char32_t code = getchar32();
+ if (code == '[' || code == 0x1B) {
+ // Discard the rest of the escape sequence
+ while ((code = getchar32()) != (char32_t) WEOF) {
+ if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
+ break;
+ }
+ }
+ }
+ } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
+ if (!widths.empty()) {
+ int count;
+ do {
+ count = widths.back();
+ widths.pop_back();
+ // Move cursor back, print space, and move cursor back again
+ for (int i = 0; i < count; i++) {
+ replace_last(' ');
+ pop_cursor();
+ }
+ pop_back_utf8_char(line);
+ } while (count == 0 && !widths.empty());
+ }
+ } else {
+ int offset = line.length();
+ append_utf8(input_char, line);
+ int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
+ if (width < 0) {
+ width = 0;
+ }
+ widths.push_back(width);
+ }
+
+ if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
+ set_display(prompt);
+ replace_last(line.back());
+ is_special_char = true;
+ }
+ }
+
+ bool has_more = multiline_input;
+ if (is_special_char) {
+ replace_last(' ');
+ pop_cursor();
+
+ char last = line.back();
+ line.pop_back();
+ if (last == '\\') {
+ line += '\n';
+ fputc('\n', out);
+ has_more = !has_more;
+ } else {
+ // llama will just eat the single space, it won't act as a space
+ if (line.length() == 1 && line.back() == ' ') {
+ line.clear();
+ pop_cursor();
+ }
+ has_more = false;
+ }
+ } else {
+ if (end_of_stream) {
+ has_more = false;
+ } else {
+ line += '\n';
+ fputc('\n', out);
+ }
+ }
+
+ fflush(out);
+ return has_more;
+ }
+
+ bool readline_simple(std::string & line, bool multiline_input) {
+#if defined(_WIN32)
+ std::wstring wline;
+ if (!std::getline(std::wcin, wline)) {
+ // Input stream is bad or EOF received
+ line.clear();
+ GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
+ return false;
+ }
+
+ int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
+ line.resize(size_needed);
+ WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
+#else
+ if (!std::getline(std::cin, line)) {
+ // Input stream is bad or EOF received
+ line.clear();
+ return false;
+ }
+#endif
+ if (!line.empty()) {
+ char last = line.back();
+ if (last == '/') { // Always return control on '/' symbol
+ line.pop_back();
+ return false;
+ }
+ if (last == '\\') { // '\\' changes the default action
+ line.pop_back();
+ multiline_input = !multiline_input;
+ }
+ }
+ line += '\n';
+
+ // By default, continue input if multiline_input is set
+ return multiline_input;
+ }
+
+ bool readline(std::string & line, bool multiline_input) {
+ set_display(user_input);
+
+ if (simple_io) {
+ return readline_simple(line, multiline_input);
+ }
+ return readline_advanced(line, multiline_input);
+ }
+
+}
diff --git a/common/console.h b/common/console.h
new file mode 100644
index 00000000..ec175269
--- /dev/null
+++ b/common/console.h
@@ -0,0 +1,19 @@
+// Console functions
+
+#pragma once
+
+#include <string>
+
+namespace console {
+ enum display_t {
+ reset = 0,
+ prompt,
+ user_input,
+ error
+ };
+
+ void init(bool use_simple_io, bool use_advanced_display);
+ void cleanup();
+ void set_display(display_t display);
+ bool readline(std::string & line, bool multiline_input);
+}
diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp
new file mode 100644
index 00000000..e76bd11c
--- /dev/null
+++ b/common/grammar-parser.cpp
@@ -0,0 +1,423 @@
+#include "grammar-parser.h"
+#include <cstdint>
+#include <cwchar>
+#include <string>
+#include <utility>
+#include <stdexcept>
+#include <exception>
+
+namespace grammar_parser {
+ // NOTE: assumes valid utf8 (but checks for overrun)
+ // copied from llama.cpp
+ std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+ uint8_t first_byte = static_cast<uint8_t>(*src);
+ uint8_t highbits = first_byte >> 4;
+ int len = lookup[highbits];
+ uint8_t mask = (1 << (8 - len)) - 1;
+ uint32_t value = first_byte & mask;
+ const char * end = src + len; // may overrun!
+ const char * pos = src + 1;
+ for ( ; pos < end && *pos; pos++) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+ auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
+ return result.first->second;
+ }
+
+ uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+ uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+ state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+ return next_id;
+ }
+
+ void add_rule(
+ parse_state & state,
+ uint32_t rule_id,
+ const std::vector<llama_grammar_element> & rule) {
+ if (state.rules.size() <= rule_id) {
+ state.rules.resize(rule_id + 1);
+ }
+ state.rules[rule_id] = rule;
+ }
+
+ bool is_word_char(char c) {
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+ }
+
+ std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+ const char * pos = src;
+ const char * end = src + size;
+ uint32_t value = 0;
+ for ( ; pos < end && *pos; pos++) {
+ value <<= 4;
+ char c = *pos;
+ if ('a' <= c && c <= 'f') {
+ value += c - 'a' + 10;
+ } else if ('A' <= c && c <= 'F') {
+ value += c - 'A' + 10;
+ } else if ('0' <= c && c <= '9') {
+ value += c - '0';
+ } else {
+ break;
+ }
+ }
+ if (pos != end) {
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ const char * parse_space(const char * src, bool newline_ok) {
+ const char * pos = src;
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+ if (*pos == '#') {
+ while (*pos && *pos != '\r' && *pos != '\n') {
+ pos++;
+ }
+ } else {
+ pos++;
+ }
+ }
+ return pos;
+ }
+
+ const char * parse_name(const char * src) {
+ const char * pos = src;
+ while (is_word_char(*pos)) {
+ pos++;
+ }
+ if (pos == src) {
+ throw std::runtime_error(std::string("expecting name at ") + src);
+ }
+ return pos;
+ }
+
+ std::pair<uint32_t, const char *> parse_char(const char * src) {
+ if (*src == '\\') {
+ switch (src[1]) {
+ case 'x': return parse_hex(src + 2, 2);
+ case 'u': return parse_hex(src + 2, 4);
+ case 'U': return parse_hex(src + 2, 8);
+ case 't': return std::make_pair('\t', src + 2);
+ case 'r': return std::make_pair('\r', src + 2);
+ case 'n': return std::make_pair('\n', src + 2);
+ case '\\':
+ case '"':
+ case '[':
+ case ']':
+ return std::make_pair(src[1], src + 2);
+ default:
+ throw std::runtime_error(std::string("unknown escape at ") + src);
+ }
+ } else if (*src) {
+ return decode_utf8(src);
+ }
+ throw std::runtime_error("unexpected end of input");
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested);
+
+ const char * parse_sequence(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ std::vector<llama_grammar_element> & out_elements,
+ bool is_nested) {
+ size_t last_sym_start = out_elements.size();
+ const char * pos = src;
+ while (*pos) {
+ if (*pos == '"') { // literal string
+ pos++;
+ last_sym_start = out_elements.size();
+ while (*pos != '"') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '[') { // char range(s)
+ pos++;
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+ if (*pos == '^') {
+ pos++;
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
+ }
+ last_sym_start = out_elements.size();
+ while (*pos != ']') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ enum llama_gretype type = last_sym_start < out_elements.size()
+ ? LLAMA_GRETYPE_CHAR_ALT
+ : start_type;
+
+ out_elements.push_back({type, char_pair.first});
+ if (pos[0] == '-' && pos[1] != ']') {
+ auto endchar_pair = parse_char(pos + 1);
+ pos = endchar_pair.second;
+ out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+ }
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (is_word_char(*pos)) { // rule reference
+ const char * name_end = parse_name(pos);
+ uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+ pos = parse_space(name_end, is_nested);
+ last_sym_start = out_elements.size();
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+ } else if (*pos == '(') { // grouping
+ // parse nested alternates into synthesized rule
+ pos = parse_space(pos + 1, true);
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+ last_sym_start = out_elements.size();
+ // output reference to synthesized rule
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ if (*pos != ')') {
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
+ if (last_sym_start == out_elements.size()) {
+ throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
+ }
+
+ // apply transformation to previous symbol (last_sym_start to end) according to
+ // rewrite rules:
+ // S* --> S' ::= S S' |
+ // S+ --> S' ::= S S' | S
+ // S? --> S' ::= S |
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ std::vector<llama_grammar_element> sub_rule;
+ // add preceding symbol to generated rule
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ if (*pos == '*' || *pos == '+') {
+ // cause generated rule to recurse
+ sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ }
+ // mark start of alternate def
+ sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ if (*pos == '+') {
+ // add preceding symbol as alternate only for '+' (otherwise empty)
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ }
+ sub_rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule(state, sub_rule_id, sub_rule);
+
+ // in original rule, replace previous symbol with reference to generated rule
+ out_elements.resize(last_sym_start);
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+
+ pos = parse_space(pos + 1, is_nested);
+ } else {
+ break;
+ }
+ }
+ return pos;
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested) {
+ std::vector<llama_grammar_element> rule;
+ const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+ while (*pos == '|') {
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ pos = parse_space(pos + 1, true);
+ pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+ }
+ rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule(state, rule_id, rule);
+ return pos;
+ }
+
+ const char * parse_rule(parse_state & state, const char * src) {
+ const char * name_end = parse_name(src);
+ const char * pos = parse_space(name_end, false);
+ size_t name_len = name_end - src;
+ uint32_t rule_id = get_symbol_id(state, src, name_len);
+ const std::string name(src, name_len);
+
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
+ }
+ pos = parse_space(pos + 3, true);
+
+ pos = parse_alternates(state, pos, name, rule_id, false);
+
+ if (*pos == '\r') {
+ pos += pos[1] == '\n' ? 2 : 1;
+ } else if (*pos == '\n') {
+ pos++;
+ } else if (*pos) {
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+ }
+ return parse_space(pos, true);
+ }
+
+ parse_state parse(const char * src) {
+ try {
+ parse_state state;
+ const char * pos = parse_space(src, true);
+ while (*pos) {
+ pos = parse_rule(state, pos);
+ }
+ return state;
+ } catch (const std::exception & err) {
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+ return parse_state();
+ }
+ }
+
+ void print_grammar_char(FILE * file, uint32_t c) {
+ if (0x20 <= c && c <= 0x7f) {
+ fprintf(file, "%c", static_cast<char>(c));
+ } else {
+ // cop out of encoding UTF-8
+ fprintf(file, "<U+%04X>", c);
+ }
+ }
+
+ bool is_char_element(llama_grammar_element elem) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_CHAR: return true;
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+ default: return false;
+ }
+ }
+
+ void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
+ for (auto elem : rule) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
+ }
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ case LLAMA_GRETYPE_ALT:
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "(%u) ", elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ case LLAMA_GRETYPE_CHAR_NOT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ case LLAMA_GRETYPE_CHAR_ALT:
+ fprintf(file, "(\"");
+ print_grammar_char(file, elem.value);
+ fprintf(file, "\") ");
+ break;
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ void print_rule(
+ FILE * file,
+ uint32_t rule_id,
+ const std::vector<llama_grammar_element> & rule,
+ const std::map<uint32_t, std::string> & symbol_id_names) {
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+ throw std::runtime_error(
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+ }
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+ llama_grammar_element elem = rule[i];
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ throw std::runtime_error(
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
+ std::to_string(i));
+ case LLAMA_GRETYPE_ALT:
+ fprintf(file, "| ");
+ break;
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ fprintf(file, "[");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_NOT:
+ fprintf(file, "[^");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ fprintf(file, "-");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_ALT:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ print_grammar_char(file, elem.value);
+ break;
+ }
+ if (is_char_element(elem)) {
+ switch (rule[i + 1].type) {
+ case LLAMA_GRETYPE_CHAR_ALT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ break;
+ default:
+ fprintf(file, "] ");
+ }
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ void print_grammar(FILE * file, const parse_state & state) {
+ try {
+ std::map<uint32_t, std::string> symbol_id_names;
+ for (auto kv : state.symbol_ids) {
+ symbol_id_names[kv.second] = kv.first;
+ }
+ for (size_t i = 0, end = state.rules.size(); i < end; i++) {
+ // fprintf(file, "%zu: ", i);
+ // print_rule_binary(file, state.rules[i]);
+ print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
+ // fprintf(file, "\n");
+ }
+ } catch (const std::exception & err) {
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+ }
+ }
+
+ std::vector<const llama_grammar_element *> parse_state::c_rules() {
+ std::vector<const llama_grammar_element *> ret;
+ for (const auto & rule : rules) {
+ ret.push_back(rule.data());
+ }
+ return ret;
+ }
+}
diff --git a/common/grammar-parser.h b/common/grammar-parser.h
new file mode 100644
index 00000000..9037d727
--- /dev/null
+++ b/common/grammar-parser.h
@@ -0,0 +1,29 @@
+// Implements a parser for an extended Backus-Naur form (BNF), producing the
+// binary context-free grammar format specified by llama.h. Supports character
+// ranges, grouping, and repetition operators. As an example, a grammar for
+// arithmetic might look like:
+//
+// root ::= expr
+// expr ::= term ([-+*/] term)*
+// term ::= num | "(" space expr ")" space
+// num ::= [0-9]+ space
+// space ::= [ \t\n]*
+
+#pragma once
+#include "llama.h"
+#include <vector>
+#include <map>
+#include <cstdint>
+#include <string>
+
+namespace grammar_parser {
+ struct parse_state {
+ std::map<std::string, uint32_t> symbol_ids;
+ std::vector<std::vector<llama_grammar_element>> rules;
+
+ std::vector<const llama_grammar_element *> c_rules();
+ };
+
+ parse_state parse(const char * src);
+ void print_grammar(FILE * file, const parse_state & state);
+}