author    Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>    2023-10-11 13:35:46 -0600
committer GitHub <noreply@github.com>    2023-10-11 22:35:46 +0300
commit    70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch)
tree      9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d
parent    8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff)
common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
  Code formatting cleanups and add some comments
  Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions. Fix comments that were out of sync with the pull.
* Use more consistent naming convention for sampling contexts
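The new common/sampling.h API introduced by this patch is used roughly as follows; a minimal sketch, assuming ctx, params, last_tokens, i_batch, and seq_id are already set up by the caller (as in examples/parallel below):

    #include "common.h"
    #include "sampling.h"

    // one sampling context for the whole program; per-sequence state
    // (mirostat mu, grammar copy) is created lazily for each seq id
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, /*grammar =*/ NULL);

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(llama_get_model(ctx)));

    // sample the next token for sequence seq_id, reading logits at batch index i_batch
    const llama_token id = llama_sampling_sample(ctx, /*ctx_guidance =*/ NULL, ctx_sampling,
                                                 last_tokens, candidates, i_batch, seq_id);

    // once the sequence ends (or its id is reused), drop its sampler state
    llama_sampling_context_reset(ctx_sampling, seq_id);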
-rw-r--r--  Makefile                                      |  86
-rw-r--r--  build.zig                                     |   5
-rw-r--r--  common/CMakeLists.txt                         |   2
-rw-r--r--  common/common.cpp                             | 228
-rw-r--r--  common/common.h                               |  56
-rw-r--r--  common/sampling.cpp                           | 166
-rw-r--r--  common/sampling.h                             | 108
-rw-r--r--  examples/embd-input/embd-input-lib.cpp        |  19
-rw-r--r--  examples/infill/infill.cpp                    |  18
-rw-r--r--  examples/main/main.cpp                        |  18
-rw-r--r--  examples/parallel/parallel.cpp                |   6
-rw-r--r--  examples/save-load-state/save-load-state.cpp  |   5
-rw-r--r--  examples/server/server.cpp                    | 100
-rw-r--r--  examples/speculative/speculative.cpp          |  12
14 files changed, 495 insertions, 334 deletions
diff --git a/Makefile b/Makefile
index 571ad3bb..705fa1ef 100644
--- a/Makefile
+++ b/Makefile
@@ -178,6 +178,24 @@ else
MK_CPPFLAGS += -DNDEBUG
endif
+ifdef LLAMA_SANITIZE_THREAD
+ MK_CFLAGS += -fsanitize=thread -g
+ MK_CXXFLAGS += -fsanitize=thread -g
+ MK_LDFLAGS += -fsanitize=thread -g
+endif
+
+ifdef LLAMA_SANITIZE_ADDRESS
+ MK_CFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
+ MK_CXXFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
+ MK_LDFLAGS += -fsanitize=address -fno-omit-frame-pointer -g
+endif
+
+ifdef LLAMA_SANITIZE_UNDEFINED
+ MK_CFLAGS += -fsanitize=undefined -g
+ MK_CXXFLAGS += -fsanitize=undefined -g
+ MK_LDFLAGS += -fsanitize=undefined -g
+endif
+
ifdef LLAMA_SERVER_VERBOSE
MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif
@@ -526,7 +544,13 @@ OBJS += ggml-alloc.o ggml-backend.o
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@
-common.o: common/common.cpp common/common.h build-info.h common/log.h
+COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
+COMMON_DEPS = $(COMMON_H_DEPS) common.o sampling.o
+
+common.o: common/common.cpp $(COMMON_H_DEPS)
+ $(CXX) $(CXXFLAGS) -c $< -o $@
+
+sampling.o: common/sampling.cpp $(COMMON_H_DEPS)
$(CXX) $(CXXFLAGS) -c $< -o $@
console.o: common/console.cpp common/console.h
@@ -548,19 +572,19 @@ clean:
# Examples
#
-main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS)
+main: examples/main/main.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo
-infill: examples/infill/infill.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS)
+infill: examples/infill/infill.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+simple: examples/simple/simple.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+batched: examples/batched/batched.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
batched-bench: examples/batched-bench/batched-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
@@ -572,53 +596,53 @@ quantize: examples/quantize/quantize.cpp build-info.h ggml.
quantize-stats: examples/quantize-stats/quantize-stats.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+perplexity: examples/perplexity/perplexity.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+embedding: examples/embedding/embedding.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
-$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
-embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o common.o train.o $(OBJS)
+train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+llama-bench: examples/llama-bench/llama-bench.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o train.o $(OBJS)
+baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o common.o train.o $(OBJS)
+finetune: examples/finetune/finetune.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+export-lora: examples/export-lora/export-lora.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
+speculative: examples/speculative/speculative.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+parallel: examples/parallel/parallel.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
ifdef LLAMA_METAL
@@ -659,40 +683,40 @@ vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
q8dot: pocs/vdot/q8dot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
-tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o common.o grammar-parser.o $(OBJS)
+tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
+tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-grad0: tests/test-grad0.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-opt: tests/test-opt.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-tokenizer-0-falcon: tests/test-tokenizer-0-falcon.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
-tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+tests/test-tokenizer-1-llama: tests/test-tokenizer-1-llama.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
tests/test-c.o: tests/test-c.c llama.h
diff --git a/build.zig b/build.zig
index fdc5bc08..0b74cee4 100644
--- a/build.zig
+++ b/build.zig
@@ -128,17 +128,18 @@ pub fn build(b: *std.build.Builder) !void {
const llama = make.obj("llama", "llama.cpp");
const common = make.obj("common", "common/common.cpp");
const console = make.obj("console", "common/console.cpp");
+ const sampling = make.obj("sampling", "common/sampling.cpp");
const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
const train = make.obj("train", "common/train.cpp");
- _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, console, grammar_parser });
+ _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser });
_ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
- const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, grammar_parser });
+ const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser });
if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32");
}
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 951aa834..fbb0ff09 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -5,6 +5,8 @@ set(TARGET common)
add_library(${TARGET} OBJECT
common.h
common.cpp
+ sampling.h
+ sampling.cpp
console.h
console.cpp
grammar-parser.h
diff --git a/common/common.cpp b/common/common.cpp
index 0f55c33a..4214e63a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -107,6 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg;
gpt_params default_params;
const std::string arg_prefix = "--";
+ llama_sampling_params & sparams = params.sampling_params;
for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -184,7 +185,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- params.top_k = std::stoi(argv[i]);
+ sparams.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) {
invalid_param = true;
@@ -216,73 +217,73 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- params.top_p = std::stof(argv[i]);
+ sparams.top_p = std::stof(argv[i]);
} else if (arg == "--temp") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.temp = std::stof(argv[i]);
+ sparams.temp = std::stof(argv[i]);
} else if (arg == "--tfs") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.tfs_z = std::stof(argv[i]);
+ sparams.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.typical_p = std::stof(argv[i]);
+ sparams.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.repeat_last_n = std::stoi(argv[i]);
+ sparams.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.repeat_penalty = std::stof(argv[i]);
+ sparams.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.frequency_penalty = std::stof(argv[i]);
+ sparams.frequency_penalty = std::stof(argv[i]);
} else if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.presence_penalty = std::stof(argv[i]);
+ sparams.presence_penalty = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.mirostat = std::stoi(argv[i]);
+ sparams.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.mirostat_eta = std::stof(argv[i]);
+ sparams.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.mirostat_tau = std::stof(argv[i]);
+ sparams.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.cfg_negative_prompt = argv[i];
+ sparams.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-negative-prompt-file") {
if (++i >= argc) {
invalid_param = true;
@@ -294,16 +295,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
- if (!params.cfg_negative_prompt.empty() && params.cfg_negative_prompt.back() == '\n') {
- params.cfg_negative_prompt.pop_back();
+ std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
+ if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
+ sparams.cfg_negative_prompt.pop_back();
}
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
- params.cfg_scale = std::stof(argv[i]);
+ sparams.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -512,7 +513,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--ignore-eos") {
params.ignore_eos = true;
} else if (arg == "--no-penalize-nl") {
- params.penalize_nl = false;
+ sparams.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") {
if (++i >= argc) {
invalid_param = true;
@@ -524,7 +525,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string value_str;
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
- params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
} else {
throw std::exception();
}
@@ -627,6 +628,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
+ const llama_sampling_params & sparams = params.sampling_params;
+
printf("usage: %s [options]\n", argv[0]);
printf("\n");
printf("options:\n");
@@ -659,19 +662,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
- printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
- printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
- printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
- printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
- printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
- printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
- printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
- printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
+ printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
+ printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
+ printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
+ printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
+ printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
+ printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
+ printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
+ printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
- printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
- printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
- printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
+ printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
+ printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta);
+ printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau);
printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
printf(" modifies the likelihood of token appearing in the completion,\n");
printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
@@ -682,7 +685,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n");
- printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+ printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
@@ -690,7 +693,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
- printf(" --temp N temperature (default: %.1f)\n", (double)params.temp);
+ printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
@@ -840,7 +843,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
if (params.ignore_eos) {
- params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+ params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}
{
@@ -933,127 +936,6 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
}
//
-// Sampling utils
-//
-
-llama_token llama_sample_token(
- struct llama_context * ctx,
- struct llama_context * ctx_guidance,
- struct llama_grammar * grammar,
- const struct gpt_params & params,
- const std::vector<llama_token> & last_tokens,
- std::vector<llama_token_data> & candidates,
- int idx) {
- const int n_ctx = llama_n_ctx(ctx);
- const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-
- const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
- const float top_p = params.top_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
- const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
- const float repeat_penalty = params.repeat_penalty;
- const float alpha_presence = params.presence_penalty;
- const float alpha_frequency = params.frequency_penalty;
- const int mirostat = params.mirostat;
- const float mirostat_tau = params.mirostat_tau;
- const float mirostat_eta = params.mirostat_eta;
- const bool penalize_nl = params.penalize_nl;
-
- llama_token id = 0;
-
- float * logits = llama_get_logits_ith(ctx, idx);
-
- // Apply params.logit_bias map
- for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
- logits[it->first] += it->second;
- }
-
- candidates.clear();
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
-
- llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
-
- if (ctx_guidance) {
- llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
- }
-
- // apply penalties
- if (!last_tokens.empty()) {
- const float nl_logit = logits[llama_token_nl(ctx)];
- const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
-
- llama_sample_repetition_penalty(ctx, &cur_p,
- last_tokens.data() + last_tokens.size() - last_n_repeat,
- last_n_repeat, repeat_penalty);
- llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
- last_tokens.data() + last_tokens.size() - last_n_repeat,
- last_n_repeat, alpha_frequency, alpha_presence);
-
- if (!penalize_nl) {
- for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(ctx)) {
- cur_p.data[idx].logit = nl_logit;
- break;
- }
- }
- }
- }
-
- if (grammar != NULL) {
- llama_sample_grammar(ctx, &cur_p, grammar);
- }
-
- if (temp <= 0) {
- // Greedy sampling
- id = llama_sample_token_greedy(ctx, &cur_p);
- } else {
- if (mirostat == 1) {
- static float mirostat_mu = 2.0f * mirostat_tau;
- const int mirostat_m = 100;
- llama_sample_temp(ctx, &cur_p, temp);
- id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
- } else if (mirostat == 2) {
- static float mirostat_mu = 2.0f * mirostat_tau;
- llama_sample_temp(ctx, &cur_p, temp);
- id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
- } else {
- // Temperature sampling
- size_t min_keep = std::max(1, params.n_probs);
- llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
- llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
- llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
- llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
- llama_sample_temp(ctx, &cur_p, temp);
-
- {
- const int n_top = 10;
- LOG("top %d candidates:\n", n_top);
-
- for (int i = 0; i < n_top; i++) {
- const llama_token id = cur_p.data[i].id;
- LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
- }
- }
-
- id = llama_sample_token(ctx, &cur_p);
-
- LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
- }
- }
- // printf("`%d`", candidates_p.size);
-
- if (grammar != NULL) {
- llama_grammar_accept_token(ctx, grammar, id);
- }
-
- return id;
-}
-
-//
// YAML utils
//
@@ -1204,6 +1086,8 @@ std::string get_sortable_timestamp() {
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
+ const llama_sampling_params & sparams = params.sampling_params;
+
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
@@ -1250,21 +1134,21 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
- dump_string_yaml_multiline(stream, "cfg_negative_prompt", params.cfg_negative_prompt.c_str());
- fprintf(stream, "cfg_scale: %f # default: 1.0\n", params.cfg_scale);
+ dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
+ fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
- fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
+ fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
- const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx));
- const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+ const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx));
+ const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str());
@@ -1277,7 +1161,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
fprintf(stream, "logit_bias:\n");
- for (std::pair<llama_token, float> lb : params.logit_bias) {
+ for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) {
continue;
}
@@ -1301,30 +1185,30 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
- fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
- fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau);
- fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
+ fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
+ fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
+ fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
- fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", params.n_probs);
+ fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false");
- fprintf(stream, "no_penalize_nl: %s # default: false\n", !params.penalize_nl ? "true" : "false");
+ fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false");
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
- fprintf(stream, "presence_penalty: %f # default: 0.0\n", params.presence_penalty);
+ fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
- fprintf(stream, "repeat_penalty: %f # default: 1.1\n", params.repeat_penalty);
+ fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) {
@@ -1342,15 +1226,15 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
- fprintf(stream, "temp: %f # default: 0.8\n", params.temp);
+ fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
- fprintf(stream, "tfs: %f # default: 1.0\n", params.tfs_z);
+ fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
- fprintf(stream, "top_k: %d # default: 40\n", params.top_k);
- fprintf(stream, "top_p: %f # default: 0.95\n", params.top_p);
- fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p);
+ fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
+ fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
+ fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}
diff --git a/common/common.h b/common/common.h
index c8021527..fa115536 100644
--- a/common/common.h
+++ b/common/common.h
@@ -4,6 +4,8 @@
#include "llama.h"
+#include "sampling.h"
+
#define LOG_NO_FILE_LINE_FUNCTION
#include "log.h"
@@ -49,31 +51,12 @@ struct gpt_params {
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t n_beams = 0; // if non-zero then use beam search of given width.
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
- // sampling parameters
- int32_t top_k = 40; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- float temp = 0.80f; // 1.0 = disabled
- float repeat_penalty = 1.10f; // 1.0 = disabled
- int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float frequency_penalty = 0.00f; // 0.0 = disabled
- float presence_penalty = 0.00f; // 0.0 = disabled
- int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
-
- std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
- // Classifier-Free Guidance
- // https://arxiv.org/abs/2306.17806
- std::string cfg_negative_prompt; // string to help guidance
- float cfg_scale = 1.f; // How strong is guidance
+ // sampling parameters
+ struct llama_sampling_params sampling_params;
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
@@ -115,7 +98,6 @@ struct gpt_params {
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
bool instruct = false; // instruction mode (used for Alpaca models)
- bool penalize_nl = true; // consider newlines as a repeatable token
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
@@ -181,36 +163,6 @@ std::string llama_detokenize_bpe(
const std::vector<llama_token> & tokens);
//
-// Sampling utils
-//
-
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-//
-// required:
-// - ctx: context to use for sampling
-// - params: sampling parameters
-//
-// optional:
-// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-// - grammar: grammar to use for sampling, ignore if NULL
-// - last_tokens: needed for repetition penalty, ignore if empty
-// - idx: sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-// - token: sampled token
-// - candidates: vector of candidate tokens
-//
-llama_token llama_sample_token(
- struct llama_context * ctx,
- struct llama_context * ctx_guidance,
- struct llama_grammar * grammar,
- const struct gpt_params & params,
- const std::vector<llama_token> & last_tokens,
- std::vector<llama_token_data> & candidates,
- int idx = 0);
-
-//
// YAML utils
//
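With this change the sampling options are no longer top-level members of gpt_params; callers reach them through params.sampling_params (aliased as sparams throughout the diff). A small illustrative sketch of the new access pattern; the values shown are arbitrary examples, not new defaults:

    gpt_params params;
    llama_sampling_params & sparams = params.sampling_params;
    sparams.top_k        = 40;     // previously params.top_k
    sparams.mirostat     = 2;      // previously params.mirostat
    sparams.mirostat_tau = 5.0f;   // previously params.mirostat_tau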
diff --git a/common/sampling.cpp b/common/sampling.cpp
new file mode 100644
index 00000000..8ce41945
--- /dev/null
+++ b/common/sampling.cpp
@@ -0,0 +1,166 @@
+#include "sampling.h"
+
+llama_sampling_context::~llama_sampling_context() {
+ for (auto & it : sequence_contexts) {
+ if (it.second.grammar != NULL) {
+ llama_grammar_free(it.second.grammar);
+ it.second.grammar = NULL;
+ }
+ }
+}
+
+llama_sampling_context llama_sampling_context_init(
+ const struct gpt_params & params,
+ llama_grammar * grammar) {
+ llama_sampling_context result;
+
+ result.params = params.sampling_params;
+ result.grammar = grammar;
+ return result;
+}
+
+// Note: Creates the context if it doesn't exist, so this always returns something.
+llama_sampler_sequence_context & llama_sampling_get_sequence_context(
+ llama_sampling_context & ctx_sampling,
+ const llama_seq_id seq) {
+ const auto it = ctx_sampling.sequence_contexts.find(seq);
+ if (it != ctx_sampling.sequence_contexts.end()) {
+ return it->second;
+ }
+ llama_sampler_sequence_context new_ctx = {
+ 2.0f * ctx_sampling.params.mirostat_tau,
+ ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
+ };
+ return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
+}
+
+bool llama_sampling_context_reset(
+ llama_sampling_context & ctx_sampling,
+ const llama_seq_id seq) {
+ const auto it = ctx_sampling.sequence_contexts.find(seq);
+ if (it == ctx_sampling.sequence_contexts.end()) return false;
+ if (it->second.grammar != NULL) {
+ llama_grammar_free(it->second.grammar);
+ it->second.grammar = NULL;
+ }
+ ctx_sampling.sequence_contexts.erase(it);
+ return true;
+}
+
+llama_token llama_sampling_sample(
+ struct llama_context * ctx,
+ struct llama_context * ctx_guidance,
+ struct llama_sampling_context & ctx_sampling,
+ const std::vector<llama_token> & last_tokens,
+ std::vector<llama_token_data> & candidates,
+ const int idx,
+ llama_seq_id seq) {
+ const int n_ctx = llama_n_ctx(ctx);
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+ const llama_sampling_params & params = ctx_sampling.params;
+ const float temp = params.temp;
+ const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
+ const float top_p = params.top_p;
+ const float tfs_z = params.tfs_z;
+ const float typical_p = params.typical_p;
+ const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+ const float repeat_penalty = params.repeat_penalty;
+ const float alpha_presence = params.presence_penalty;
+ const float alpha_frequency = params.frequency_penalty;
+ const int mirostat = params.mirostat;
+ const float mirostat_tau = params.mirostat_tau;
+ const float mirostat_eta = params.mirostat_eta;
+ const bool penalize_nl = params.penalize_nl;
+
+ llama_token id = 0;
+
+ float * logits = llama_get_logits_ith(ctx, idx);
+
+ // Apply params.logit_bias map
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ logits[it->first] += it->second;
+ }
+
+ candidates.clear();
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+
+ if (ctx_guidance) {
+ llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+ }
+
+ // apply penalties
+ if (!last_tokens.empty()) {
+ const float nl_logit = logits[llama_token_nl(ctx)];
+ const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+
+ llama_sample_repetition_penalty(ctx, &cur_p,
+ last_tokens.data() + last_tokens.size() - last_n_repeat,
+ last_n_repeat, repeat_penalty);
+ llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
+ last_tokens.data() + last_tokens.size() - last_n_repeat,
+ last_n_repeat, alpha_frequency, alpha_presence);
+
+ if (!penalize_nl) {
+ for (size_t idx = 0; idx < cur_p.size; idx++) {
+ if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+ cur_p.data[idx].logit = nl_logit;
+ break;
+ }
+ }
+ }
+ }
+
+ llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
+
+ if (ctx_seq.grammar != NULL) {
+ llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
+ }
+
+ if (temp <= 0) {
+ // Greedy sampling
+ id = llama_sample_token_greedy(ctx, &cur_p);
+ } else {
+ if (mirostat == 1) {
+ const int mirostat_m = 100;
+ llama_sample_temp(ctx, &cur_p, temp);
+ id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
+ } else if (mirostat == 2) {
+ llama_sample_temp(ctx, &cur_p, temp);
+ id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
+ } else {
+ // Temperature sampling
+ size_t min_keep = std::max(1, params.n_probs);
+ llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
+ llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
+ llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
+ llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
+ llama_sample_temp(ctx, &cur_p, temp);
+
+ {
+ const int n_top = 10;
+ LOG("top %d candidates:\n", n_top);
+
+ for (int i = 0; i < n_top; i++) {
+ const llama_token id = cur_p.data[i].id;
+ (void)id; // To avoid a warning that id is unused when logging is disabled.
+ LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
+ }
+ }
+
+ id = llama_sample_token(ctx, &cur_p);
+
+ LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+ }
+ }
+
+ if (ctx_seq.grammar != NULL) {
+ llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
+ }
+
+ return id;
+}
diff --git a/common/sampling.h b/common/sampling.h
new file mode 100644
index 00000000..0aab5d03
--- /dev/null
+++ b/common/sampling.h
@@ -0,0 +1,108 @@
+#pragma once
+
+#include "llama.h"
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+// sampling parameters
+typedef struct llama_sampling_params {
+ int32_t top_k = 40; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typical_p = 1.00f; // 1.0 = disabled
+ float temp = 0.80f; // 1.0 = disabled
+ float repeat_penalty = 1.10f; // 1.0 = disabled
+ int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float frequency_penalty = 0.00f; // 0.0 = disabled
+ float presence_penalty = 0.00f; // 0.0 = disabled
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+
+ bool penalize_nl = true; // consider newlines as a repeatable token
+
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+
+ // Classifier-Free Guidance
+ // https://arxiv.org/abs/2306.17806
+ std::string cfg_negative_prompt; // string to help guidance
+ float cfg_scale = 1.f; // How strong is guidance
+
+ std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+} llama_sampling_params;
+
+// per-sequence sampler context
+typedef struct llama_sampler_sequence_context {
+ float mirostat_mu; // mirostat sampler state
+ llama_grammar * grammar;
+} llama_sampler_sequence_context;
+
+// general sampler context
+typedef struct llama_sampling_context {
+ ~llama_sampling_context();
+
+ // parameters that will be used for sampling and when creating
+ // new llama_sampler_sequence_context instances
+ llama_sampling_params params;
+
+ // map of sequence ids to sampler contexts
+ std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
+
+ // when non-NULL, new instances of llama_sampler_sequence_context
+ // will get a copy of the grammar here
+ // note: only the pointer is stored here, it is not a copy of
+ // the grammar and shouldn't be freed
+ llama_grammar * grammar;
+} llama_sampling_context;
+
+#include "common.h"
+
+// Create a new sampling context instance.
+llama_sampling_context llama_sampling_context_init(
+ const struct gpt_params & params,
+ llama_grammar * grammar = NULL);
+
+// Fetches the sampler context for the specified sequence id (defaults to 0).
+// If the context for that sequence id doesn't already exist, it will be created with
+// default values based on the parameters in the ctx_sampling argument.
+llama_sampler_sequence_context & llama_sampling_get_sequence_context(
+ llama_sampling_context & ctx_sampling,
+ const llama_seq_id seq = 0);
+
+// Reset the sampler context for the supplied sequence id (defaults to 0).
+// This is necessary to reuse a sequence id or free memory used by sequences
+// that are no longer required.
+bool llama_sampling_context_reset(
+ llama_sampling_context & ctx_sampling,
+ const llama_seq_id seq = 0);
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+// Note: When using multiple sequences, it is the caller's responsibility to call
+// llama_sampling_context_reset when a sequence ends
+//
+// required:
+// - ctx: context to use for sampling
+// - ctx_sampling: sampling-specific context
+//
+// optional:
+// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+// - last_tokens: needed for repetition penalty, ignore if empty
+// - idx: sample from llama_get_logits_ith(ctx, idx)
+// - seq: sequence id to associate sampler state with
+//
+// returns:
+// - token: sampled token
+// - candidates: vector of candidate tokens
+//
+llama_token llama_sampling_sample(
+ struct llama_context * ctx,
+ struct llama_context * ctx_guidance,
+ struct llama_sampling_context & ctx_sampling,
+ const std::vector<llama_token> & last_tokens,
+ std::vector<llama_token_data> & candidates,
+ const int idx = 0,
+ llama_seq_id seq = 0);
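The per-sequence state that fixes the original mirostat bug lives in llama_sampler_sequence_context: each sequence id gets its own mirostat_mu (initialized to 2.0f * mirostat_tau) and its own copy of the grammar, replacing the function-static mu that the removed llama_sample_token shared across all sequences. A short sketch of fetching that state directly, assuming the ctx_sampling and seq_id from the earlier sketch:

    llama_sampler_sequence_context & ctx_seq =
        llama_sampling_get_sequence_context(ctx_sampling, seq_id);
    // ctx_seq.mirostat_mu and ctx_seq.grammar belong to this sequence only;
    // llama_sampling_sample() fetches (and creates) this context internally.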
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 99e6bdad..87a5a1c2 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){
llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
+ llama_sampling_params & sparams = params.sampling_params;
// int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token
- const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
- const float top_p = params.top_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
+ const float temp = sparams.temp;
+ const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
+ const float top_p = sparams.top_p;
+ const float tfs_z = sparams.tfs_z;
+ const float typical_p = sparams.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
- const int mirostat = params.mirostat;
- const float mirostat_tau = params.mirostat_tau;
- const float mirostat_eta = params.mirostat_eta;
+ const int mirostat = sparams.mirostat;
+ const float mirostat_tau = sparams.mirostat_tau;
+ const float mirostat_eta = sparams.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
@@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map
- for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index d994de5e..187623f5 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -104,6 +104,7 @@ static void sigint_handler(int signo) {
int main(int argc, char ** argv) {
gpt_params params;
+ llama_sampling_params & sparams = params.sampling_params;
g_params = &params;
if (!gpt_params_parse(argc, argv, params)) {
@@ -206,7 +207,7 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- if (params.cfg_scale > 1.f) {
+ if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
@@ -269,9 +270,9 @@ int main(int argc, char ** argv) {
int guidance_offset = 0;
int original_prompt_len = 0;
if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+ LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
- guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+ guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -312,7 +313,7 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+ LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -358,7 +359,7 @@ int main(int argc, char ** argv) {
}
}
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
- params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+ sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
@@ -376,8 +377,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
{
- auto it = params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+ if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
@@ -434,6 +435,7 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model);
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -552,7 +554,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
- const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 775a5a20..b39a67d9 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -109,6 +109,7 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
+ llama_sampling_params & sparams = params.sampling_params;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log"));
@@ -179,7 +180,7 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- if (params.cfg_scale > 1.f) {
+ if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
@@ -257,9 +258,9 @@ int main(int argc, char ** argv) {
int guidance_offset = 0;
int original_prompt_len = 0;
if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+ LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
- guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+ guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -343,7 +344,7 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+ LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -395,7 +396,7 @@ int main(int argc, char ** argv) {
}
}
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
- params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+ sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
@@ -413,8 +414,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
{
- auto it = params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+ if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
@@ -469,6 +470,7 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model);
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -625,7 +627,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
- const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
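
The main.cpp changes above boil down to a small usage pattern: sampling options are read from params.sampling_params, a llama_sampling_context is created once, and each generation step goes through llama_sampling_sample instead of llama_sample_token. The sketch below is assembled only from calls visible in this diff; the helper name sample_one and its parameter list are invented for illustration, and it assumes model, ctx, ctx_guidance and ctx_sampling were set up as in the example.

    #include <vector>
    #include "common.h"   // the sampling declarations added by this change are assumed reachable from here

    // Minimal sketch of one generation step with the refactored sampling API.
    // Per-sampler state (mirostat mu, grammar) is carried by ctx_sampling rather
    // than being re-derived from gpt_params on every call.
    static llama_token sample_one(llama_context * ctx, llama_context * ctx_guidance,
                                  llama_model * model, llama_sampling_context & ctx_sampling,
                                  std::vector<llama_token> & last_tokens) {
        std::vector<llama_token_data> candidates;
        candidates.reserve(llama_n_vocab(model));

        const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);

        // keep a sliding window of previous tokens for the repetition penalties
        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);
        return id;
    }
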
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 04f1e45b..63ddcd8e 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
+
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -339,7 +341,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
- const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+ const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +386,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
-
+ llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1;
}
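
The parallel example is where the multi-sequence fix actually lands: one llama_sampling_context (created with a NULL grammar) is shared by all clients, every llama_sampling_sample call passes the client's batch index and seq_id, and llama_sampling_context_reset clears that sequence's state when the client finishes, so a reused seq_id does not inherit stale mirostat state. Below is a minimal sketch of that pattern using only the calls visible in this diff; the helper name and parameters are illustrative, and batch decoding plus the example's client bookkeeping are elided.

    #include <vector>
    #include "common.h"   // sampling declarations added by this change (assumed reachable from here)

    // Sketch: sample one token for a given sequence of a shared sampling context,
    // then drop that sequence's sampler state once it has finished.
    static void sample_for_sequence(llama_context * ctx, llama_model * model,
                                    llama_sampling_context & ctx_sampling,
                                    std::vector<llama_token> & tokens_prev,
                                    int i_batch, int seq_id, bool finished) {
        std::vector<llama_token_data> candidates;
        candidates.reserve(llama_n_vocab(model));

        // seq_id selects this sequence's state (e.g. mirostat mu) inside ctx_sampling,
        // which is the core of the fix in this commit.
        const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, tokens_prev, candidates, i_batch, seq_id);

        tokens_prev.erase(tokens_prev.begin());
        tokens_prev.push_back(id);

        if (finished) {
            // reset before the seq_id is handed to the next client
            llama_sampling_context_reset(ctx_sampling, seq_id);
        }
    }
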
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index acc6dbdf..f9e3c98a 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -8,9 +8,10 @@
int main(int argc, char ** argv) {
gpt_params params;
+ llama_sampling_params & sparams = params.sampling_params;
params.seed = 42;
params.n_threads = 4;
- params.repeat_last_n = 64;
+ sparams.repeat_last_n = 64;
params.prompt = "The quick brown fox";
if (!gpt_params_parse(argc, argv, params)) {
@@ -24,7 +25,7 @@ int main(int argc, char ** argv) {
}
auto n_past = 0;
- auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
+ auto last_n_tokens_data = std::vector<llama_token>(sparams.repeat_last_n, 0);
// init
llama_model * model;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8c5318c6..58af78de 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -200,6 +200,7 @@ struct llama_server_context
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
+ llama_sampling_context ctx_sampling;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
@@ -254,6 +255,7 @@ struct llama_server_context
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
+ ctx_sampling = llama_sampling_context_init(params, NULL);
}
}
@@ -329,8 +331,8 @@ struct llama_server_context
grammar_parser::print_grammar(stderr, parsed_grammar);
{
- auto it = params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
+ if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
@@ -339,6 +341,7 @@ struct llama_server_context
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
+ ctx_sampling = llama_sampling_context_init(params, grammar);
return true;
}
@@ -550,12 +553,12 @@ struct llama_server_context
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
- result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
+ result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- const int32_t n_probs = params.n_probs;
- if (params.temp <= 0 && n_probs > 0)
+ const int32_t n_probs = params.sampling_params.n_probs;
+ if (params.sampling_params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
@@ -630,7 +633,7 @@ struct llama_server_context
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text;
- if (params.n_probs > 0)
+ if (params.sampling_params.n_probs > 0)
{
generated_token_probs.push_back(token_with_probs);
}
@@ -1018,34 +1021,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama)
{
- const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
- const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
+ const auto & sparams = llama.params.sampling_params;
+ const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
+ const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
- {"temp", llama.params.temp},
- {"top_k", llama.params.top_k},
- {"top_p", llama.params.top_p},
- {"tfs_z", llama.params.tfs_z},
- {"typical_p", llama.params.typical_p},
- {"repeat_last_n", llama.params.repeat_last_n},
- {"repeat_penalty", llama.params.repeat_penalty},
- {"presence_penalty", llama.params.presence_penalty},
- {"frequency_penalty", llama.params.frequency_penalty},
- {"mirostat", llama.params.mirostat},
- {"mirostat_tau", llama.params.mirostat_tau},
- {"mirostat_eta", llama.params.mirostat_eta},
- {"penalize_nl", llama.params.penalize_nl},
+ {"temp", sparams.temp},
+ {"top_k", sparams.top_k},
+ {"top_p", sparams.top_p},
+ {"tfs_z", sparams.tfs_z},
+ {"typical_p", sparams.typical_p},
+ {"repeat_last_n", sparams.repeat_last_n},
+ {"repeat_penalty", sparams.repeat_penalty},
+ {"presence_penalty", sparams.presence_penalty},
+ {"frequency_penalty", sparams.frequency_penalty},
+ {"mirostat", sparams.mirostat},
+ {"mirostat_tau", sparams.mirostat_tau},
+ {"mirostat_eta", sparams.mirostat_eta},
+ {"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", llama.stream},
- {"logit_bias", llama.params.logit_bias},
- {"n_probs", llama.params.n_probs},
+ {"logit_bias", sparams.logit_bias},
+ {"n_probs", sparams.n_probs},
{"grammar", llama.params.grammar},
};
}
@@ -1094,7 +1098,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"timings", format_timings(llama)},
};
- if (llama.params.n_probs > 0)
+ if (llama.params.sampling_params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1110,7 +1114,7 @@ static json format_partial_response(
{"stop", false},
};
- if (llama.params.n_probs > 0)
+ if (llama.params.sampling_params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1142,26 +1146,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_server_context &llama)
{
gpt_params default_params;
+ const auto & default_sparams = default_params.sampling_params;
+ auto & sparams = llama.params.sampling_params;
llama.stream = json_value(body, "stream", false);
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
- llama.params.top_k = json_value(body, "top_k", default_params.top_k);
- llama.params.top_p = json_value(body, "top_p", default_params.top_p);
- llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
- llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
- llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
- llama.params.temp = json_value(body, "temperature", default_params.temp);
- llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
- llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
- llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
- llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
- llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
- llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
- llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+ sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
+ sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
+ sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
+ sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
+ sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
+ sparams.temp = json_value(body, "temperature", default_sparams.temp);
+ sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
+ sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
+ sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
+ sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
+ sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+ sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+ sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
- llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
+ sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0)
{
@@ -1172,10 +1178,10 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.prompt = "";
}
- llama.params.logit_bias.clear();
+ sparams.logit_bias.clear();
if (json_value(body, "ignore_eos", false))
{
- llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
+ sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
}
const auto &logit_bias = body.find("logit_bias");
@@ -1191,11 +1197,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla
{
if (el[1].is_number())
{
- llama.params.logit_bias[tok] = el[1].get<float>();
+ sparams.logit_bias[tok] = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
- llama.params.logit_bias[tok] = -INFINITY;
+ sparams.logit_bias[tok] = -INFINITY;
}
}
}
@@ -1215,6 +1221,8 @@ static void parse_options_completion(const json &body, llama_server_context &lla
}
}
+ llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
@@ -1423,7 +1431,7 @@ int main(int argc, char **argv)
}
auto probs = llama.generated_token_probs;
- if (llama.params.n_probs > 0 && llama.stopped_word) {
+ if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
@@ -1475,7 +1483,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
- if (llama.params.n_probs > 0) {
+ if (llama.params.sampling_params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1596,7 +1604,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
- if (llama.params.n_probs > 0) {
+ if (llama.params.sampling_params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 75a2e5e2..018dbf9a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
+
const auto t_dec_start = ggml_time_us();
while (true) {
@@ -134,7 +136,7 @@ int main(int argc, char ** argv) {
while (true) {
// sample from the target model
- llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+ llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
// remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin());
@@ -211,7 +213,13 @@ int main(int argc, char ** argv) {
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}
- grammar_dft = llama_grammar_copy(grammar_tgt);
+ // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
+ // that will need to change.
+ auto it = ctx_sampling.sequence_contexts.find(0);
+ GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
+ // This is necessary because each sequence id in sequence_contexts
+ // uses a copy of the original grammar.
+ grammar_dft = llama_grammar_copy(it->second.grammar);
LOG("copied target grammar to draft grammar\n");
}