author     Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-07-27 07:55:01 +0200
committer  GitHub <noreply@github.com>                             2024-07-27 07:55:01 +0200
commit     154e0d75fccf1784fe9ff6fd76a630b66563da3d (patch)
tree       81ce6dbb5b1900c1aa78a879f0593c694cab9d27 /tests
parent     0684c3e9c70d49323b4fc517128cbe222cab7f96 (diff)
Merge mainline llama.cpp (#3)
* Merging mainline - WIP

* Merging mainline - WIP

  AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%)
  lower, as is so often the case with llama.cpp/ggml after some
  "improvements" have been made.

* Merging mainline - fix Metal

* Remove check

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'tests')
-rw-r--r--  tests/CMakeLists.txt                  |  32
-rw-r--r--  tests/test-backend-ops.cpp            | 104
-rw-r--r--  tests/test-chat-template.cpp          |  60
-rw-r--r--  tests/test-double-float.cpp           |   4
-rw-r--r--  tests/test-grammar-integration.cpp    | 646
-rwxr-xr-x  tests/test-json-schema-to-grammar.cpp | 455
-rw-r--r--  tests/test-llama-grammar.cpp          |  24
-rw-r--r--  tests/test-quantize-fns.cpp           |   4
-rw-r--r--  tests/test-quantize-perf.cpp          |   2
-rw-r--r--  tests/test-rope.cpp                   |   1
-rw-r--r--  tests/test-tokenizer-0.cpp            |  12
-rw-r--r--  tests/test-tokenizer-1-bpe.cpp        |  35
-rw-r--r--  tests/test-tokenizer-1-spm.cpp        |  33
-rw-r--r--  tests/test-tokenizer-random.py        | 345
14 files changed, 1353 insertions, 404 deletions
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cfa70731..0207e3a5 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -70,21 +70,19 @@ add_executable(test-tokenizer-0 test-tokenizer-0.cpp)
target_link_libraries(test-tokenizer-0 PRIVATE common)
install(TARGETS test-tokenizer-0 RUNTIME)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-phi-3.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-bert-bge ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bert-bge.gguf)
-# TODO: enable when fixed
-# https://github.com/ggerganov/llama.cpp/pull/7036
-#llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-#llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
-#llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-phi-3.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
# build test-tokenizer-1-bpe target once and add many tests
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
@@ -92,16 +90,14 @@ target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
install(TARGETS test-tokenizer-1-bpe RUNTIME)
# TODO: disabled due to slowness
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-stablelm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm.gguf)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
+#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-bloom ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf)
# build test-tokenizer-1-spm target once and add many tests
add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 7c504e93..2c03c60d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1,7 +1,6 @@
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
-#include <ggml-backend-impl.h>
#include <algorithm>
#include <array>
@@ -80,8 +79,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
im = nullptr;
}
}
+
ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
+ // TODO: other cases
+ //#pragma omp parallel for
+ //for (int i = 0; i < tensor->ne[1]; i++) {
+ // ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+ // i * tensor->ne[0], 1, tensor->ne[0], im);
+ //}
+
ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
} else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
// This is going to create some weird integers though.
@@ -760,7 +767,7 @@ struct test_dup : public test_case {
}
test_dup(ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {10, 10, 10, 1},
+ std::array<int64_t, 4> ne = {10, 10, 20, 1},
std::array<int64_t, 4> permute = {0, 0, 0, 0})
: type(type), ne(ne), permute(permute),
_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
@@ -780,9 +787,15 @@ struct test_cpy : public test_case {
const ggml_type type_src;
const ggml_type type_dst;
const std::array<int64_t, 4> ne;
+ const std::array<int64_t, 4> permute;
+ bool _src_use_permute;
std::string vars() override {
- return VARS_TO_STR3(type_src, type_dst, ne);
+ return VARS_TO_STR4(type_src, type_dst, ne, permute);
+ }
+
+ double max_nmse_err() override {
+ return 1e-6;
}
size_t op_size(ggml_tensor * t) override {
@@ -790,12 +803,18 @@ struct test_cpy : public test_case {
}
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {10, 10, 10, 1})
- : type_src(type_src), type_dst(type_dst), ne(ne) {}
+ std::array<int64_t, 4> ne = {10, 10, 10, 1},
+ std::array<int64_t, 4> permute = {0, 0, 0, 0},
+ bool _dst_use_permute = false)
+ : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
+ _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
- ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne.data());
+ if (_src_use_permute) {
+ src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+ }
+ ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
ggml_tensor * out = ggml_cpy(ctx, src, dst);
return out;
}
@@ -1171,6 +1190,7 @@ struct test_soft_max : public test_case {
}
};
+
// GGML_OP_ROPE
struct test_rope : public test_case {
const ggml_type type;
@@ -1263,6 +1283,32 @@ struct test_pool2d : public test_case {
}
};
+// GGML_OP_CONV_TRANSPOSE_1D
+struct test_conv_transpose_1d : public test_case {
+ const std::array<int64_t, 4> ne_input;
+ const std::array<int64_t, 4> ne_kernel;
+
+ const int s0; // stride
+ const int p0; // padding
+ const int d0; // dilation
+
+ std::string vars() override {
+ return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
+ }
+
+ test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
+ std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
+ int s0 = 1, int p0 = 0, int d0 = 1)
+ : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+ ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+ ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
+ return out;
+ }
+};
+
// GGML_OP_IM2COL
struct test_im2col : public test_case {
const ggml_type type_input;
@@ -1276,7 +1322,7 @@ struct test_im2col : public test_case {
// padding
const int p0;
const int p1;
- // dilatation
+ // dilation
const int d0;
const int d1;
// mode
@@ -2049,6 +2095,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+ GGML_TYPE_BF16,
};
// unary ops
@@ -2094,6 +2141,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_conv_transpose_1d());
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
+ test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
+
+
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
@@ -2106,12 +2163,22 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
test_cases.emplace_back(new test_dup(GGML_TYPE_I16));
+ test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {0, 2, 1, 3}));
+ test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
+ test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));
+ test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (ggml_type type_dst : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+ }
+ }
+ for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+ test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
}
}
@@ -2161,6 +2228,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
}
+#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
@@ -2180,6 +2248,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
}
}
+#else
+ // m = a rows
+ // n = b rows
+ // k = cols
+ std::uniform_int_distribution<> dist_m(1, 128);
+ std::uniform_int_distribution<> dist_n(16, 128);
+ std::uniform_int_distribution<> dist_k(1, 16);
+ for (int i = 0; i < 1000; i++) {
+ for (ggml_type type_a : all_types) {
+ for (ggml_type type_b : {GGML_TYPE_F32}) {
+ int m = dist_m(rng);
+ int n = dist_n(rng);
+ int k = dist_k(rng) * ggml_blck_size(type_a);
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1, 1}, {1, 1}));
+ }
+ }
+ }
+#endif
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
@@ -2243,7 +2329,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n = 0; n < 10; ++n) {
int64_t ne0 = dist_ne0(rng);
int64_t ne1 = dist_ne1(rng);
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
}
exponent <<= 1;
@@ -2262,7 +2348,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
}
-
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
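Note on the new conv-transpose cases above: the input/kernel shapes are chosen so the output width is easy to verify by hand. A minimal sketch of the expected geometry, assuming ggml follows the standard transposed-convolution output-size formula (the cases above only exercise p0 = 0):

    #include <cstdint>
    #include <cstdio>

    // Sketch, not the test code: expected output width of a 1D transposed
    // convolution, assuming out = (in - 1)*s0 - 2*p0 + d0*(k - 1) + 1.
    static int64_t conv_transpose_1d_out_w(int64_t in_w, int64_t k_w, int s0, int p0, int d0) {
        return (in_w - 1)*s0 - 2*p0 + d0*(k_w - 1) + 1;
    }

    int main() {
        // Mirrors test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1):
        // input width 3, kernel width 2, stride 3 -> output width 8
        printf("%lld\n", (long long) conv_transpose_1d_out_w(3, 2, 3, 0, 1));
    }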
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index cef9a650..a8222cae 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -1,4 +1,3 @@
-#include <iostream>
#include <string>
#include <vector>
#include <sstream>
@@ -7,6 +6,7 @@
#include <cassert>
#include "llama.h"
+#include "common.h"
int main(void) {
llama_chat_message conversation[] = {
@@ -56,7 +56,15 @@ int main(void) {
//Phi-3-medium
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
//Phi-3-vision
- "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
+ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+ // ChatGLM3
+ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+ // ChatGLM4
+ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+ // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
+ // DeepSeek-V2
+ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
};
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B
@@ -93,6 +101,14 @@ int main(void) {
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
//Phi-3-vision
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+ // ChatGLM3
+ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+ // ChatGLM4
+ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+ // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+ // DeepSeek-V2
+ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
};
std::vector<char> formatted_chat(1024);
int32_t res;
@@ -116,8 +132,46 @@ int main(void) {
);
formatted_chat.resize(res);
std::string output(formatted_chat.data(), formatted_chat.size());
- std::cout << output << "\n-------------------------\n";
+ printf("%s\n", output.c_str());
+ printf("-------------------------\n");
assert(output == expected);
}
+
+
+ // test llama_chat_format_single for system message
+ printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
+ std::vector<llama_chat_msg> chat2;
+ llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+
+ auto fmt_sys = [&](std::string tmpl) {
+ auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
+ printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+ printf("-------------------------\n");
+ return output;
+ };
+ assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+ assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+ assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
+ assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+
+
+ // test llama_chat_format_single for user message
+ printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
+ chat2.push_back({"system", "You are a helpful assistant"});
+ chat2.push_back({"user", "Hello"});
+ chat2.push_back({"assistant", "I am assistant"});
+ llama_chat_msg new_msg{"user", "How are you"};
+
+ auto fmt_single = [&](std::string tmpl) {
+ auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
+ printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+ printf("-------------------------\n");
+ return output;
+ };
+ assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+ assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+ assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+ assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+
return 0;
}
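The llama_chat_format_single assertions above depend on one property: formatting a single new message yields exactly the text it appends to the already-formatted history. A sketch of that delta computation, assuming the std::string llama_chat_apply_template wrapper from common.h (not shown in this diff):

    // Sketch only: format the history without and with the new message,
    // and return the suffix that the new message contributed.
    static std::string format_single_sketch(const std::string & tmpl,
                                            std::vector<llama_chat_msg> chat,
                                            const llama_chat_msg & new_msg,
                                            bool add_ass) {
        const std::string fmt_past = llama_chat_apply_template(nullptr, tmpl, chat, false);
        chat.push_back(new_msg);
        const std::string fmt_new  = llama_chat_apply_template(nullptr, tmpl, chat, add_ass);
        return fmt_new.substr(fmt_past.size());
    }

Under this reading, fmt_sys("gemma") == "" above is consistent: a lone system message contributes no text of its own, since gemma merges it into the next user message.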
diff --git a/tests/test-double-float.cpp b/tests/test-double-float.cpp
index 753dae91..6aac4737 100644
--- a/tests/test-double-float.cpp
+++ b/tests/test-double-float.cpp
@@ -14,7 +14,7 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdouble-promotion"
-// ggml.c::quantize_row_q4_0_reference
+// ggml.c::quantize_row_q4_0_ref
inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; }
// ggml.c::ggml_silu_f32
@@ -24,7 +24,7 @@ inline static float silu_orig(float x) {
#pragma GCC diagnostic pop
-// ggml.c::quantize_row_q4_0_reference
+// ggml.c::quantize_row_q4_0_ref
inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; }
// ggml.c::ggml_silu_f32
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 96f90c01..68f971bf 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -15,8 +15,6 @@
using json = nlohmann::ordered_json;
-//#define INCLUDE_FAILING_TESTS 1
-
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
@@ -36,31 +34,36 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
static bool test_build_grammar_fails(const std::string & grammar_str) {
fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
bool grammar_fails = false;
- try {
- build_grammar(grammar_str);
+ llama_grammar * grammar = build_grammar(grammar_str);
+ if (grammar != nullptr) {
fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
- } catch (const std::exception & err) {
+ } else {
grammar_fails = true;
fprintf(stdout, " ✅︎\n");
}
return grammar_fails;
}
-static bool match_string(const std::string & input, llama_grammar* grammar) {
+static bool match_string(const std::string & input, llama_grammar * grammar) {
auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first;
+ const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
+ llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- auto prev_stacks = grammar->stacks;
- llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
- if (grammar->stacks.empty()) {
+ const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+
+ llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+
+ if (cur_stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}
- for (const auto & stack : grammar->stacks) {
+ for (const auto & stack : cur_stacks) {
if (stack.empty()) {
// An empty stack means that the grammar has been completed
return true;
@@ -77,7 +80,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
auto grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
- auto original_stacks = grammar->stacks;
+ const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
+
+ llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
fprintf(stderr, " 🔵 Valid strings:\n");
@@ -114,7 +119,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(matched);
// Reset the grammar stacks
- grammar->stacks = original_stacks;
+ cur_stacks = original_stacks;
}
fprintf(stderr, " 🟠 Invalid strings:\n");
@@ -134,7 +139,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(!matched);
// Reset the grammar stacks
- grammar->stacks = original_stacks;
+ cur_stacks = original_stacks;
}
// Clean up allocated memory
@@ -148,6 +153,250 @@ static void test_schema(const std::string & test_desc, const std::string & schem
}
static void test_simple_grammar() {
+ test_schema(
+ "min 0",
+ R"""({
+ "type": "integer",
+ "minimum": 0
+ })""",
+ // Passing strings
+ {
+ "0",
+ "10",
+ "12",
+ "10000",
+ },
+ // Failing strings
+ {
+ "-1",
+ "-10",
+ "-10000",
+ "-100000000000000000000000000000000",
+ "100000000000000000000000000000000",
+ "00",
+ "01",
+ "-0",
+ }
+ );
+ test_schema(
+ "min 2",
+ // Schema
+ R"""({
+ "type": "integer",
+ "minimum": 2
+ })""",
+ // Passing strings
+ {
+ "2",
+ "3",
+ "4",
+ "10",
+ "20",
+ "1234567890000000",
+ },
+ // Failing strings
+ {
+ "0",
+ "1",
+ "-1",
+ "-100",
+ "0",
+ "1",
+ "01",
+ "02",
+ "12345678900000000",
+ }
+ );
+ test_schema(
+ "min 456",
+ R"""({
+ "type": "integer",
+ "minimum": 456
+ })""",
+ // Passing strings
+ {
+ "456",
+ "4560",
+ "457",
+ "460",
+ "500",
+ },
+ // Failing strings
+ {
+ "455",
+ "356",
+ "50",
+ "050",
+ "-1",
+ "-456",
+ }
+ );
+ test_schema(
+ "min -123",
+ R"""({
+ "type": "integer",
+ "minimum": -123
+ })""",
+ // Passing strings
+ {
+ "-123",
+ "-122",
+ "-11",
+ "-1",
+ "0",
+ "1",
+ "123",
+ "1234",
+ "2345",
+ },
+ // Failing strings
+ {
+ "-1234",
+ "-124",
+ }
+ );
+
+ test_schema(
+ "max 9999",
+ // Schema
+ R"""({
+ "type": "integer",
+ "maximum": 9999
+ })""",
+ // Passing strings
+ {
+ "-99999",
+ "0",
+ "9999",
+ },
+ // Failing strings
+ {
+ "10000",
+ "99991",
+ }
+ );
+ test_schema(
+ "max -9999",
+ // Schema
+ R"""({
+ "type": "integer",
+ "maximum": -9999
+ })""",
+ // Passing strings
+ {
+ "-10000",
+ "-9999",
+ },
+ // Failing strings
+ {
+ "-9998",
+ "0",
+ "9999",
+ }
+ );
+ test_schema(
+ "min 5 max 30",
+ // Schema
+ R"""({
+ "type": "integer",
+ "minimum": 5,
+ "maximum": 30
+ })""",
+ // Passing strings
+ {
+ "5",
+ "10",
+ "30",
+ },
+ // Failing strings
+ {
+ "05",
+ "4",
+ "-1",
+ "31",
+ "123",
+ "0123",
+ }
+ );
+ test_schema(
+ "min -1 max 1",
+ R"""({
+ "type": "integer",
+ "minimum": -1,
+ "maximum": 1
+ })""",
+ // Passing strings
+ {
+ "-1",
+ "0",
+ "1",
+ },
+ // Failing strings
+ {
+ "-11",
+ "-10",
+ "-2",
+ "2",
+ "10",
+ "11",
+ }
+ );
+ test_schema(
+ "min -123 max 42",
+ R"""({
+ "type": "integer",
+ "minimum": -123,
+ "maximum": 42
+ })""",
+ // Passing strings
+ {
+ "-123",
+ "-122",
+ "-13",
+ "-11",
+ "-2",
+ "-1",
+ "0",
+ "1",
+ "5",
+ "10",
+ "39",
+ "40",
+ "42",
+ },
+ // Failing strings
+ {
+ "-0123",
+ "-124",
+ "-1123",
+ "-200",
+ "43",
+ "123",
+ "0123",
+ }
+ );
+ test_schema(
+ "exclusive min / max",
+ // Schema
+ R"""({
+ "type": "integer",
+ "exclusiveMinimum": 0,
+ "exclusiveMaximum": 10000
+ })""",
+ // Passing strings
+ {
+ "1",
+ "9999",
+ },
+ // Failing strings
+ {
+ "0",
+ "01",
+ "10000",
+ "99999",
+ }
+ );
+
// Test case for a simple grammar
test_grammar(
"simple grammar",
@@ -510,7 +759,7 @@ static void test_json_schema() {
)""",
// Passing strings
{
- "{}",
+ R"""({})""",
R"""({"foo": "bar"})""",
},
// Failing strings
@@ -518,7 +767,7 @@ static void test_json_schema() {
"",
"[]",
"null",
- "\"\"",
+ R"""("")""",
"true",
}
);
@@ -526,16 +775,14 @@ static void test_json_schema() {
test_schema(
"exotic formats (list)",
// Schema
- R"""(
- {
+ R"""({
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
- }
- )""",
+ })""",
// Passing strings
{
// "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
@@ -554,125 +801,113 @@ static void test_json_schema() {
test_schema(
"string",
// Schema
- R"""(
- {
- "type": "string"
- }
- )""",
+ R"""({
+ "type": "string"
+ })""",
// Passing strings
{
- "\"foo\"",
- "\"bar\"",
- "\"\"",
+ R"""("foo")""",
+ R"""("bar")""",
+ R"""("")""",
},
// Failing strings
{
- "{}",
- "\"foo\": \"bar\"",
+ R"""({})""",
+ R"""("foo": "bar")""",
}
);
test_schema(
"string w/ min length 1",
// Schema
- R"""(
- {
- "type": "string",
- "minLength": 1
- }
- )""",
+ R"""({
+ "type": "string",
+ "minLength": 1
+ })""",
// Passing strings
{
- "\"foo\"",
- "\"bar\"",
+ R"""("foo")""",
+ R"""("bar")""",
},
// Failing strings
{
- "\"\"",
- "{}",
- "\"foo\": \"bar\"",
+ R"""("")""",
+ R"""({})""",
+ R"""("foo": "bar")""",
}
);
test_schema(
"string w/ min length 3",
// Schema
- R"""(
- {
+ R"""({
"type": "string",
"minLength": 3
- }
- )""",
+ })""",
// Passing strings
{
- "\"foo\"",
- "\"bar\"",
- "\"foobar\"",
+ R"""("foo")""",
+ R"""("bar")""",
+ R"""("foobar")""",
},
// Failing strings
{
- "\"\"",
- "\"f\"",
- "\"fo\"",
+ R"""("")""",
+ R"""("f")""",
+ R"""("fo")""",
}
);
test_schema(
"string w/ max length",
// Schema
- R"""(
- {
- "type": "string",
- "maxLength": 3
- }
- )""",
+ R"""({
+ "type": "string",
+ "maxLength": 3
+ })""",
// Passing strings
{
- "\"foo\"",
- "\"bar\"",
- "\"\"",
- "\"f\"",
- "\"fo\"",
+ R"""("foo")""",
+ R"""("bar")""",
+ R"""("")""",
+ R"""("f")""",
+ R"""("fo")""",
},
// Failing strings
{
- "\"foobar\"",
+ R"""("foobar")""",
}
);
test_schema(
"string w/ min & max length",
// Schema
- R"""(
- {
- "type": "string",
- "minLength": 1,
- "maxLength": 4
- }
- )""",
+ R"""({
+ "type": "string",
+ "minLength": 1,
+ "maxLength": 4
+ })""",
// Passing strings
{
- "\"foo\"",
- "\"bar\"",
- "\"f\"",
- "\"barf\"",
+ R"""("foo")""",
+ R"""("bar")""",
+ R"""("f")""",
+ R"""("barf")""",
},
// Failing strings
{
- "\"\"",
- "\"barfo\"",
- "\"foobar\"",
+ R"""("")""",
+ R"""("barfo")""",
+ R"""("foobar")""",
}
);
test_schema(
"boolean",
// Schema
- R"""(
- {
- "type": "boolean"
- }
- )""",
+ R"""({
+ "type": "boolean"
+ })""",
// Passing strings
{
"true",
@@ -680,123 +915,171 @@ static void test_json_schema() {
},
// Failing strings
{
- "\"\"",
- "\"true\"",
- "True",
- "FALSE",
+ R"""("")""",
+ R"""("true")""",
+ R"""(True)""",
+ R"""(FALSE)""",
}
);
test_schema(
"integer",
// Schema
- R"""(
- {
- "type": "integer"
- }
- )""",
+ R"""({
+ "type": "integer"
+ })""",
// Passing strings
{
- "0",
- "12345",
- "1234567890123456"
+ R"""(0)""",
+ R"""(12345)""",
+ R"""(1234567890123456)""",
},
// Failing strings
{
- "",
- "01",
- "007",
- "12345678901234567"
+ R"""()""",
+ R"""(01)""",
+ R"""(007)""",
+ R"""(12345678901234567 )""",
}
);
test_schema(
"string const",
// Schema
- R"""(
- {
- "const": "foo"
- }
- )""",
+ R"""({
+ "const": "foo"
+ })""",
// Passing strings
{
- "\"foo\"",
+ R"""("foo")""",
},
// Failing strings
{
- "foo",
- "\"bar\"",
+ R"""(foo)""",
+ R"""("bar")""",
}
);
test_schema(
"non-string const",
// Schema
- R"""(
- {
- "const": true
- }
- )""",
+ R"""({
+ "const": true
+ })""",
// Passing strings
{
- "true",
+ R"""(true)""",
},
// Failing strings
{
- "",
- "foo",
- "\"true\"",
+ R"""()""",
+ R"""(foo)""",
+ R"""("true")""",
}
);
test_schema(
"non-string const",
// Schema
+ R"""({
+ "enum": ["red", "amber", "green", null, 42, ["foo"]]
+ })""",
+ // Passing strings
+ {
+ R"""("red")""",
+ R"""(null)""",
+ R"""(42)""",
+ R"""(["foo"])""",
+ },
+ // Failing strings
+ {
+ R"""()""",
+ R"""(420)""",
+ R"""(true)""",
+ R"""(foo)""",
+ }
+ );
+
+ test_schema(
+ "simple pattern",
+ // Schema
+ R"""({
+ "pattern": "^[a-zA-Z0-9_-]*$"
+ })""",
+ // Passing strings
+ {
+ R"""("")""",
+ R"""("He_llo-12")""",
+ },
+ // Failing strings
+ {
+ R"""("!")""",
+ R"""("Hello World")""",
+ }
+ );
+
+ test_schema(
+ "pattern with escapes",
+ // Schema
+ R"""({
+ "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
+ })""",
+ // Passing strings
+ {
+ R"""("a^$.[]()|{}*+?b")""",
+ },
+ // Failing strings
+ {
+ R"""("ab")""",
+ }
+ );
+
+ test_schema(
+ "",
+ // Schema
R"""(
{
- "enum": ["red", "amber", "green", null, 42, ["foo"]]
+ "type": ["array", "null"],
+ "items": { "type": "string" }
}
)""",
// Passing strings
{
- "\"red\"",
"null",
- "42",
- "[\"foo\"]",
+ "[]",
+ "[\"123\"]",
+ "[\"foo\", \"bar\"]",
},
// Failing strings
{
"",
- "420",
- "true",
- "foo",
+ "[123]",
+ "\"foo\"",
+ "[\"foo\", 42]",
}
);
-
test_schema(
"min+max items",
// Schema
- R"""(
- {
- "items": {
- "type": ["number", "integer"]
- },
- "minItems": 3,
- "maxItems": 5
- }
- )""",
+ R"""({
+ "items": {
+ "type": ["number", "integer"]
+ },
+ "minItems": 3,
+ "maxItems": 5
+ })""",
// Passing strings
{
- "[1, 2, 3]",
- "[1, 2, 3, 4]",
- "[1, 2, 3, 4, 5]",
+ R"""([1, 2, 3])""",
+ R"""([1, 2, 3, 4])""",
+ R"""([1, 2, 3, 4, 5])""",
},
// Failing strings
{
- "[1, 2]",
- "[1, 2, 3, 4, 5, 6]",
- "1"
+ R"""([1, 2])""",
+ R"""([1, 2, 3, 4, 5, 6])""",
+ R"""(1)""",
}
);
@@ -804,16 +1087,14 @@ static void test_json_schema() {
test_schema(
"object properties",
// Schema
- R"""(
- {
+ R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
"street_name": { "type": "string" },
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
}
- }
- )""",
+ })""",
// Passing strings
{
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
@@ -822,13 +1103,7 @@ static void test_json_schema() {
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
// "By extension, even an empty object is valid"
R"""({})""",
- // "By default, providing additional properties is valid"
-#ifdef INCLUDE_FAILING_TESTS
- // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default.
- R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
- // TODO: Spaces should be permitted around enum values, but currently they fail to pass.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
-#endif
},
// Failing strings
{
@@ -838,16 +1113,41 @@ static void test_json_schema() {
R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
// Reorder properties
R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+ // "Additional properties default to false for generation, even though the spec says true.
+ R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+
}
);
+ test_schema(
+ "additional properties can't override other properties",
+ R"""({
+ "properties": {
+ "a": {"type": "integer"},
+ "b": {"type": "integer"}
+ },
+ "additionalProperties": true
+ })""",
+ // Passing strings
+ {
+ R"""({"a": 42})""",
+ R"""({"c": ""})""",
+ R"""({"a": 42, "c": ""})""",
+ R"""({"a_": ""})""",
+ },
+ // Failing strings
+ {
+ R"""()""",
+ R"""({"a": ""})""",
+ R"""({"a": "", "b": ""})""",
+ }
+ );
// Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
test_schema(
"object properties, additionalProperties: true",
// Schema
- R"""(
- {
+ R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
@@ -855,26 +1155,18 @@ static void test_json_schema() {
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": true
- }
- )""",
+ })""",
// Passing strings
{
// "By extension, even an empty object is valid"
R"""({})""",
-#ifdef INCLUDE_FAILING_TESTS
- // TODO: Following line should pass and doesn't
R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
// "By default, leaving out properties is valid"
- // TODO: Following line should pass and doesn't
R"""({ "street_name": "Pennsylvania" })""",
- // TODO: Following line should pass and doesn't
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
// "By default, providing additional properties is valid"
- // TODO: The following should pass, but currently FAILS. Additional properties should be permitted by default.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
- // TODO: Spaces should be permitted around enum values, but currently they fail to pass.
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
-#endif
},
// Failing strings
{
@@ -889,8 +1181,7 @@ static void test_json_schema() {
test_schema(
"required + optional props each in original order",
// Schema
- R"""(
- {
+ R"""({
"type": "object",
"properties": {
"number": { "type": "number" },
@@ -898,18 +1189,15 @@ static void test_json_schema() {
"street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
},
"additionalProperties": false
- }
- )""",
+ })""",
// Passing strings
{
R"""({ "street_name": "Pennsylvania" })""",
R"""({ "number": 1600, "street_type":"Avenue"})""",
R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
-#ifdef INCLUDE_FAILING_TESTS
- // TODO: Spaces should be permitted around enum values, but currently they fail to pass.
+ // Spaces are permitted around enum values
R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
-#endif
},
// Failing strings
{
@@ -923,18 +1211,16 @@ static void test_json_schema() {
test_schema(
"required + optional props each in original order",
// Schema
- R"""(
- {
- "properties": {
- "b": {"type": "string"},
- "a": {"type": "string"},
- "d": {"type": "string"},
- "c": {"type": "string"}
- },
- "required": ["a", "b"],
- "additionalProperties": false
- }
- )""",
+ R"""({
+ "properties": {
+ "b": {"type": "string"},
+ "a": {"type": "string"},
+ "d": {"type": "string"},
+ "c": {"type": "string"}
+ },
+ "required": ["a", "b"],
+ "additionalProperties": false
+ })""",
// Passing strings
{
R"""({"b": "foo", "a": "bar"})""",
@@ -954,8 +1240,7 @@ static void test_json_schema() {
test_schema(
"required props",
// Schema
- R"""(
- {
+ R"""({
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://example.com/product.schema.json",
"title": "Product",
@@ -1001,8 +1286,7 @@ static void test_json_schema() {
}
},
"required": [ "productId", "productName", "price" ]
- }
- )""",
+ })""",
// Passing strings
{
R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
index 87bc66b6..65486ac5 100755
--- a/tests/test-json-schema-to-grammar.cpp
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -81,6 +81,232 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
};
test({
+ SUCCESS,
+ "min 0",
+ R"""({
+ "type": "integer",
+ "minimum": 0
+ })""",
+ R"""(
+ root ::= ([0] | [1-9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 1",
+ R"""({
+ "type": "integer",
+ "minimum": 1
+ })""",
+ R"""(
+ root ::= ([1-9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 3",
+ R"""({
+ "type": "integer",
+ "minimum": 3
+ })""",
+ R"""(
+ root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 9",
+ R"""({
+ "type": "integer",
+ "minimum": 9
+ })""",
+ R"""(
+ root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 10",
+ R"""({
+ "type": "integer",
+ "minimum": 10
+ })""",
+ R"""(
+ root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 25",
+ R"""({
+ "type": "integer",
+ "minimum": 25
+ })""",
+ R"""(
+ root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "max 30",
+ R"""({
+ "type": "integer",
+ "maximum": 30
+ })""",
+ R"""(
+ root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min -5",
+ R"""({
+ "type": "integer",
+ "minimum": -5
+ })""",
+ R"""(
+ root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min -123",
+ R"""({
+ "type": "integer",
+ "minimum": -123
+ })""",
+ R"""(
+ root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "max -5",
+ R"""({
+ "type": "integer",
+ "maximum": -5
+ })""",
+ R"""(
+ root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "max 1",
+ R"""({
+ "type": "integer",
+ "maximum": 1
+ })""",
+ R"""(
+ root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "max 100",
+ R"""({
+ "type": "integer",
+ "maximum": 100
+ })""",
+ R"""(
+ root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 0 max 23",
+ R"""({
+ "type": "integer",
+ "minimum": 0,
+ "maximum": 23
+ })""",
+ R"""(
+ root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 15 max 300",
+ R"""({
+ "type": "integer",
+ "minimum": 15,
+ "maximum": 300
+ })""",
+ R"""(
+ root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min 5 max 30",
+ R"""({
+ "type": "integer",
+ "minimum": 5,
+ "maximum": 30
+ })""",
+ R"""(
+ root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min -123 max 42",
+ R"""({
+ "type": "integer",
+ "minimum": -123,
+ "maximum": 42
+ })""",
+ R"""(
+ root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min -10 max 10",
+ R"""({
+ "type": "integer",
+ "minimum": -10,
+ "maximum": 10
+ })""",
+ R"""(
+ root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
FAILURE,
"unknown type",
R"""({
@@ -247,7 +473,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"const": "foo"
})""",
R"""(
- root ::= "\"foo\""
+ root ::= "\"foo\"" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@@ -259,7 +485,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"const": 123
})""",
R"""(
- root ::= "123"
+ root ::= "123" space
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
@@ -271,8 +497,40 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
R"""(
- root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
+ root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "string array",
+ R"""({
+ "type": "array",
+ "prefixItems": { "type": "string" }
+ })""",
+ R"""(
+ char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+ root ::= "[" space (string ("," space string)*)? "]" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ string ::= "\"" char* "\"" space
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "nullable string array",
+ R"""({
+ "type": ["array", "null"],
+ "prefixItems": { "type": "string" }
+ })""",
+ R"""(
+ alternative-0 ::= "[" space (string ("," space string)*)? "]" space
+ char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+ null ::= "null" space
+ root ::= alternative-0 | null
space ::= | " " | "\n" [ \t]{0,20}
+ string ::= "\"" char* "\"" space
)"""
});
@@ -392,6 +650,44 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
test({
SUCCESS,
+ "min + max items with min + max values across zero",
+ R"""({
+ "items": {
+ "type": "integer",
+ "minimum": -12,
+ "maximum": 207
+ },
+ "minItems": 3,
+ "maxItems": 5
+ })""",
+ R"""(
+ item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
+ root ::= "[" space item ("," space item){2,4} "]" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "min + max items with min + max values",
+ R"""({
+ "items": {
+ "type": "integer",
+ "minimum": 12,
+ "maximum": 207
+ },
+ "minItems": 3,
+ "maxItems": 5
+ })""",
+ R"""(
+ item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
+ root ::= "[" space item ("," space item){2,4} "]" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
"simple regexp",
R"""({
"type": "string",
@@ -552,13 +848,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
additional-kv ::= string ":" space additional-value
- additional-kvs ::= additional-kv ( "," space additional-kv )*
additional-value ::= "[" space (number ("," space number)*)? "]" space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
- root ::= "{" space (additional-kvs )? "}" space
+ root ::= "{" space (additional-kv ( "," space additional-kv )* )? "}" space
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
@@ -635,13 +930,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
- additional-kv ::= string ":" space string
- additional-kvs ::= additional-kv ( "," space additional-kv )*
+ additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
+ additional-kv ::= additional-k ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
- root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
+ root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
@@ -659,16 +954,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
- a-rest ::= additional-kvs
- additional-kv ::= string ":" space number
- additional-kvs ::= additional-kv ( "," space additional-kv )*
+ a-rest ::= ( "," space additional-kv )*
+ additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
+ additional-kv ::= additional-k ":" space number
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
- root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
+ root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
space ::= | " " | "\n" [ \t]{0,20}
- string ::= "\"" char* "\"" space
)"""
});
@@ -678,25 +972,100 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""({
"type": "object",
"properties": {
- "a": {"type": "number"},
- "b": {"type": "number"}
+ "and": {"type": "number"},
+ "also": {"type": "number"}
},
- "required": ["a"],
+ "required": ["and"],
"additionalProperties": {"type": "number"}
})""",
R"""(
- a-kv ::= "\"a\"" space ":" space number
- additional-kv ::= string ":" space number
- additional-kvs ::= additional-kv ( "," space additional-kv )*
- b-kv ::= "\"b\"" space ":" space number
- b-rest ::= additional-kvs
+ additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
+ additional-kv ::= additional-k ":" space number
+ also-kv ::= "\"also\"" space ":" space number
+ also-rest ::= ( "," space additional-kv )*
+ and-kv ::= "\"and\"" space ":" space number
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
- root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
+ root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "optional props with empty name",
+ R"""({
+ "properties": {
+ "": {"type": "integer"},
+ "a": {"type": "integer"}
+ },
+ "additionalProperties": {"type": "integer"}
+ })""",
+ R"""(
+ -kv ::= "\"\"" space ":" space root
+ -rest ::= ( "," space a-kv )? a-rest
+ a-kv ::= "\"a\"" space ":" space integer
+ a-rest ::= ( "," space additional-kv )*
+ additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["] space
+ additional-kv ::= additional-k ":" space integer
+ char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+ integer ::= ("-"? integral-part) space
+ integral-part ::= [0] | [1-9] [0-9]{0,15}
+ root ::= ("-"? integral-part) space
+ root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "optional props with nested names",
+ R"""({
+ "properties": {
+ "a": {"type": "integer"},
+ "aa": {"type": "integer"}
+ },
+ "additionalProperties": {"type": "integer"}
+ })""",
+ R"""(
+ a-kv ::= "\"a\"" space ":" space integer
+ a-rest ::= ( "," space aa-kv )? aa-rest
+ aa-kv ::= "\"aa\"" space ":" space integer
+ aa-rest ::= ( "," space additional-kv )*
+ additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["] space
+ additional-kv ::= additional-k ":" space integer
+ char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+ integer ::= ("-"? integral-part) space
+ integral-part ::= [0] | [1-9] [0-9]{0,15}
+ root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
+ space ::= | " " | "\n" [ \t]{0,20}
+ )"""
+ });
+
+ test({
+ SUCCESS,
+ "optional props with common prefix",
+ R"""({
+ "properties": {
+ "ab": {"type": "integer"},
+ "ac": {"type": "integer"}
+ },
+ "additionalProperties": {"type": "integer"}
+ })""",
+ R"""(
+ ab-kv ::= "\"ab\"" space ":" space integer
+ ab-rest ::= ( "," space ac-kv )? ac-rest
+ ac-kv ::= "\"ac\"" space ":" space integer
+ ac-rest ::= ( "," space additional-kv )*
+ additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["] space
+ additional-kv ::= additional-k ":" space integer
+ char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+ integer ::= ("-"? integral-part) space
+ integral-part ::= [0] | [1-9] [0-9]{0,15}
+ root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
space ::= | " " | "\n" [ \t]{0,20}
- string ::= "\"" char* "\"" space
)"""
});
@@ -870,26 +1239,30 @@ int main() {
}
});
- if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python -c \"import sys; exit(1) if sys.version_info < (3, 8) else print('Python version is sufficient')\"") == 0)) {
- test_all("Python", [](const TestCase & tc) {
- write("test-json-schema-input.tmp", tc.schema);
- tc.verify_status(std::system(
- "python ./examples/json_schema_to_grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
- tc.verify(read("test-grammar-output.tmp"));
- });
+ if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
+ fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else {
- fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
- }
+ if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python -c \"import sys; exit(1) if sys.version_info < (3, 8) else print('Python version is sufficient')\"") == 0)) {
+ test_all("Python", [](const TestCase & tc) {
+ write("test-json-schema-input.tmp", tc.schema);
+ tc.verify_status(std::system(
+ "python ./examples/json_schema_to_grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
+ tc.verify(read("test-grammar-output.tmp"));
+ });
+ } else {
+ fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
+ }
- if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
- test_all("JavaScript", [](const TestCase & tc) {
- write("test-json-schema-input.tmp", tc.schema);
- tc.verify_status(std::system(
- "node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
- tc.verify(read("test-grammar-output.tmp"));
- });
- } else {
- fprintf(stderr, "\033[33mWARNING: Node not found, skipping JavaScript JSON schema -> grammar tests.\n\033[0m");
+ if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {
+ test_all("JavaScript", [](const TestCase & tc) {
+ write("test-json-schema-input.tmp", tc.schema);
+ tc.verify_status(std::system(
+ "node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
+ tc.verify(read("test-grammar-output.tmp"));
+ });
+ } else {
+ fprintf(stderr, "\033[33mWARNING: Node not found, skipping JavaScript JSON schema -> grammar tests.\n\033[0m");
+ }
}
test_all("Check Expectations Validity", [](const TestCase & tc) {
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
index 27ca4d26..1f3a267b 100644
--- a/tests/test-llama-grammar.cpp
+++ b/tests/test-llama-grammar.cpp
@@ -2,10 +2,12 @@
#undef NDEBUG
#endif
-#include "llama.cpp" // TODO: not great
+#define LLAMA_API_INTERNAL
+#include "llama.h"
#include "grammar-parser.h"
#include <cassert>
+#include <stdexcept>
int main()
{
@@ -112,10 +114,14 @@ int main()
}
}
- llama_grammar *grammar = NULL;
+ llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+ grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ if (grammar == nullptr)
+ {
+ throw std::runtime_error("Failed to initialize llama_grammar");
+ }
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
@@ -168,7 +174,7 @@ int main()
}};
auto index = 0;
- for (auto stack : grammar->stacks)
+ for (auto stack : llama_grammar_get_stacks(grammar))
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
@@ -370,13 +376,13 @@ int main()
},
};
- std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
+ std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
- for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
+ for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
{
- rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
+ rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
all_rejects.push_back(rejects);
}
@@ -397,6 +403,6 @@ int main()
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
- delete grammar;
+ llama_grammar_free(grammar);
return 0;
}
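The larger change in this file is dropping the `#include "llama.cpp"` hack in favor of the LLAMA_API_INTERNAL surface: the grammar becomes an opaque pointer, read through llama_grammar_get_rules / llama_grammar_get_stacks and released with llama_grammar_free instead of delete. A condensed sketch of that lifecycle, reusing the grammar-parser helpers the test already depends on (make_grammar is an illustrative wrapper):

    #define LLAMA_API_INTERNAL
    #include "llama.h"
    #include "grammar-parser.h"
    #include <stdexcept>
    #include <vector>

    static llama_grammar * make_grammar(const char * gbnf) {
        grammar_parser::parse_state parsed = grammar_parser::parse(gbnf);
        std::vector<const llama_grammar_element *> rules(parsed.c_rules());
        llama_grammar * grammar = llama_grammar_init(
            rules.data(), rules.size(), parsed.symbol_ids.at("root"));
        if (grammar == nullptr) {
            // guard against a null return, as the updated test now does
            throw std::runtime_error("failed to initialize llama_grammar");
        }
        return grammar; // inspect via llama_grammar_get_stacks(grammar),
                        // release via llama_grammar_free(grammar)
    }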
diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp
index cf4664b3..638b921b 100644
--- a/tests/test-quantize-fns.cpp
+++ b/tests/test-quantize-fns.cpp
@@ -69,7 +69,7 @@ static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test
qfns.from_float(test_data, tmp_q.data(), test_size);
qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
- qfns.from_float_reference(test_data, tmp_q.data(), test_size);
+ qfns.from_float_ref(test_data, tmp_q.data(), test_size);
qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
@@ -92,7 +92,7 @@ static float dot_product_error(
auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
- qfns.from_float_reference(test_data1, tmp_q1.data(), test_size);
+ qfns.from_float_ref(test_data1, tmp_q1.data(), test_size);
vdot.from_float(test_data2, tmp_q2.data(), test_size);
float result = INFINITY;
diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp
index 48d9fae3..24e06605 100644
--- a/tests/test-quantize-perf.cpp
+++ b/tests/test-quantize-perf.cpp
@@ -285,7 +285,7 @@ int main(int argc, char * argv[]) {
for (size_t size : params.test_sizes) {
printf(" %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
auto quantize_fn = [&](void) -> float {
- qfns.from_float_reference(test_data1, test_q1, size);
+ qfns.from_float_ref(test_data1, test_q1, size);
return test_q1[0];
};
size_t quantized_size = ggml_row_size(type, size);
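Both quantization tests pick up ggml's rename of the scalar reference hook on ggml_type_traits_t, from from_float_reference to from_float_ref. A short sketch of how the optimized and reference paths are compared, under the same ggml_internal_get_type_traits API the tests use (the Q8_0 type and buffer sizes are illustrative; n must be a multiple of the type's block size):

    #include "ggml.h"
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    static float max_abs_diff_q8_0(const float * data, size_t n) {
        ggml_type_traits_t qfns = ggml_internal_get_type_traits(GGML_TYPE_Q8_0);
        std::vector<uint8_t> q_opt(ggml_row_size(GGML_TYPE_Q8_0, n));
        std::vector<uint8_t> q_ref(q_opt.size());
        qfns.from_float    (data, q_opt.data(), n); // optimized (SIMD) path
        qfns.from_float_ref(data, q_ref.data(), n); // scalar reference path
        std::vector<float> out_opt(n), out_ref(n);
        qfns.to_float(q_opt.data(), out_opt.data(), n);
        qfns.to_float(q_ref.data(), out_ref.data(), n);
        float m = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            m = std::max(m, std::fabs(out_opt[i] - out_ref[i]));
        }
        return m;
    }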
diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp
index f0895ffa..8159e276 100644
--- a/tests/test-rope.cpp
+++ b/tests/test-rope.cpp
@@ -218,4 +218,3 @@ int main(int /*argc*/, const char ** /*argv*/) {
return 0;
}
-
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index d478f104..d3d21331 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -195,11 +195,11 @@ int main(int argc, char **argv) {
const bool add_special = false;
for (const auto & test_kv : k_tests) {
- const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special);
+ const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
printf("\n");
printf("src: '%s'\n", test_kv.first.c_str());
- printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
+ printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
printf("tok: ");
for (const auto & tok : res) {
printf("%d ", tok);
@@ -216,8 +216,8 @@ int main(int argc, char **argv) {
if (!correct) {
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
- llama_detokenize_bpe(ctx, res).c_str(),
- llama_detokenize_bpe(ctx, test_kv.second).c_str());
+ llama_detokenize(ctx, res).c_str(),
+ llama_detokenize(ctx, test_kv.second).c_str());
fprintf(stderr, "%s : expected tokens: ", __func__);
for (const auto & t : test_kv.second) {
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
@@ -253,7 +253,7 @@ int main(int argc, char **argv) {
{
const auto t_start = ggml_time_us();
- res = llama_tokenize(ctx, text, add_special);
+ res = llama_tokenize(ctx, text, add_special, false);
const auto t_end = ggml_time_us();
@@ -272,7 +272,7 @@ int main(int argc, char **argv) {
}
for (const auto & tok : res) {
- //ofs << tok << " '" << string_strip(llama_detokenize_bpe(ctx, std::vector<int>{tok})) << "'" << std::endl;
+ //ofs << tok << " '" << string_strip(llama_detokenize(ctx, std::vector<int>{tok})) << "'" << std::endl;
ofs << tok << "\n";
}
}
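test-tokenizer-0 now goes through the unified helpers from common: llama_tokenize takes an explicit parse_special flag, and a single llama_detokenize replaces the vocab-specific llama_detokenize_bpe / llama_detokenize_spm pair. A round-trip check in terms of those helpers might look like this (round_trips is an illustrative name):

    #include "common.h"
    #include <string>
    #include <vector>

    // Tokenize without BOS/EOS and without parsing special tokens,
    // then require that detokenization reproduces the input exactly.
    static bool round_trips(llama_context * ctx, const std::string & text) {
        std::vector<llama_token> toks =
            llama_tokenize(ctx, text, /*add_special*/ false, /*parse_special*/ false);
        return llama_detokenize(ctx, toks) == text;
    }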
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 209a04ad..9498387e 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -11,6 +11,7 @@
#include <string>
#include <thread>
#include <vector>
+#include <atomic>
int main(int argc, char **argv) {
if (argc < 2 || argc > 3) {
@@ -63,7 +64,10 @@ int main(int argc, char **argv) {
}
}
- GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
+ //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
+ if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
+ return 99;
+ }
#ifdef _WIN32
// We need this for unicode console support
@@ -74,7 +78,7 @@ int main(int argc, char **argv) {
const int n_vocab = llama_n_vocab(model);
for (int i = 0; i < n_vocab; ++i) {
- std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
+ std::string str = llama_detokenize(ctx, std::vector<int>(1, i));
try {
auto cps = unicode_cpts_from_utf8(str);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
@@ -90,7 +94,7 @@ int main(int argc, char **argv) {
fprintf(stderr, "]\n");
return 2;
}
- std::string check = llama_detokenize_bpe(ctx, tokens);
+ std::string check = llama_detokenize(ctx, tokens);
if (check != str) {
fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
__func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -108,26 +112,23 @@ int main(int argc, char **argv) {
std::vector<std::thread> threads(nthread);
+ std::atomic_int errcode = {};
+
for (int i = 0; i < nthread; ++i) {
- threads[i] = std::thread([i, nthread, ctx]() {
- for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
- if (!( // NOLINT
- (cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 &&
- (cp < 0x13 || cp > 0x17) && cp != 0x19 &&
- (cp < 0x1c || cp > 0x1e) &&
- (cp < 0xd800 || cp > 0xdfff) &&
- (cp < 0x00040000 || cp >= 0x000e0000)
- )) {
+ threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+ for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+ if ((0x0000D800 <= cp && cp <= 0x0000DFFF) || // surrogates \p{Cs}
+ (0x00040000 <= cp && cp <= 0x000E0000)) { // undefined \p{Cn}
continue;
}
std::string str = unicode_cpt_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
- std::string check = llama_detokenize_bpe(ctx, tokens);
+ std::string check = llama_detokenize(ctx, tokens);
if (cp != 9601 && str != check) {
- fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+ fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
cp, check.c_str(), check.length(), str.c_str(), str.length());
- std::exit(3);
+ errcode = 3;
}
}
});
@@ -136,6 +137,10 @@ int main(int argc, char **argv) {
for (auto & t : threads) {
t.join();
}
+
+ if (errcode) {
+ return errcode;
+ }
}
llama_free_model(model);
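Worth noting in the threaded loop above: std::exit(3) from a worker thread would terminate the process while sibling threads still run; the patch instead records the failure in a shared std::atomic_int, lets every worker wind down, joins them all, and only then returns. The pattern in isolation, as a runnable toy:

    #include <atomic>
    #include <thread>
    #include <vector>

    int main() {
        std::atomic_int errcode = {};
        std::vector<std::thread> threads;
        for (int i = 0; i < 4; ++i) {
            threads.emplace_back([i, &errcode]() {
                // checking !errcode makes all workers bail out early
                // once the first failure has been recorded
                for (int cp = i; !errcode && cp < 1000; cp += 4) {
                    if (cp == 777) { // stand-in for a detokenization mismatch
                        errcode = 3;
                    }
                }
            });
        }
        for (auto & t : threads) {
            t.join(); // always join, even on failure
        }
        return errcode; // non-zero tells the test harness something failed
    }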
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
index ac2333dd..7ca9e2ca 100644
--- a/tests/test-tokenizer-1-spm.cpp
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -11,6 +11,7 @@
#include <string>
#include <thread>
#include <vector>
+#include <atomic>
int main(int argc, char ** argv) {
if (argc < 2) {
@@ -51,7 +52,10 @@ int main(int argc, char ** argv) {
}
}
- GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+ //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+ if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
+ return 99;
+ }
#ifdef _WIN32
// We need this for unicode console support
@@ -62,9 +66,9 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model);
for (int i = 0; i < n_vocab; ++i) {
- std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));
- std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
- std::string check = llama_detokenize_spm(ctx, tokens);
+ std::string str = llama_detokenize(ctx, std::vector<int>(1, i), true);
+ std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
+ std::string check = llama_detokenize(ctx, tokens);
if (check != str) {
fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
__func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -78,20 +82,23 @@ int main(int argc, char ** argv) {
std::vector<std::thread> threads(nthread);
+ std::atomic_int errcode = {};
+
for (int i = 0; i < nthread; ++i) {
- threads[i] = std::thread([i, nthread, ctx]() {
- for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
- if (cp >= 0xd800 && cp <= 0xdfff) {
+ threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+ for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+ if ((0x0000D800 <= cp && cp <= 0x0000DFFF) || // surrogates \p{Cs}
+ (0x00040000 <= cp && cp <= 0x000E0000)) { // undefined \p{Cn}
continue;
}
std::string str = unicode_cpt_to_utf8(cp);
- std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
- std::string check = llama_detokenize_spm(ctx, tokens);
+ std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
+ std::string check = llama_detokenize(ctx, tokens);
if (cp != 9601 && str != check) {
- fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+ fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
cp, check.c_str(), check.length(), str.c_str(), str.length());
- std::exit(3);
+ errcode = 3;
}
}
});
@@ -100,6 +107,10 @@ int main(int argc, char ** argv) {
for (auto & t : threads) {
t.join();
}
+
+ if (errcode) {
+ return errcode;
+ }
}
llama_free_model(model);
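The SPM variant gets the same atomic-errcode treatment, and both threaded loops now share one codepoint filter, skipping UTF-16 surrogates plus the range llama.cpp's unicode tables treat as undefined, while iterating up to 0x00110000 exclusive. The filter in isolation:

    #include <cstdint>

    static bool skip_cpt(uint32_t cp) {
        return (0x0000D800 <= cp && cp <= 0x0000DFFF)  // surrogates \p{Cs}
            || (0x00040000 <= cp && cp <= 0x000E0000); // undefined \p{Cn}
    }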
diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index a07c52fb..9ebe6c89 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -6,6 +6,8 @@
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
#
+from __future__ import annotations
+
import time
import logging
import argparse
@@ -13,10 +15,12 @@ import subprocess
import random
import unicodedata
-from typing import Callable, Iterator
+from pathlib import Path
+from typing import Any, Iterator, cast
+from typing_extensions import Buffer
import cffi
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer
logger = logging.getLogger("test-tokenizer-random")
@@ -24,17 +28,20 @@ logger = logging.getLogger("test-tokenizer-random")
class LibLlama:
- DEFAULT_PATH_LLAMA_H = "./llama.h"
- DEFAULT_PATH_LIBLLAMA = "./build/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
+ DEFAULT_PATH_LLAMA_H = "./include/llama.h"
+ DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
+ DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
- def __init__(self, path_llama_h: str = None, path_libllama: str = None):
+ def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
+ path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
- (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
+ (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
self.lib.llama_backend_init()
- def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
- cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
+ def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
+ cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
+ cmd += ["-I" + path for path in path_includes] + [path_llama_h]
res = subprocess.run(cmd, stdout=subprocess.PIPE)
assert (res.returncode == 0)
source = res.stdout.decode()
@@ -65,7 +72,7 @@ class LibLlama:
class LibLlamaModel:
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
- self.lib = libllama.lib
+ self.lib: Any = libllama.lib
self.ffi = libllama.ffi
if isinstance(mparams, dict):
mparams = libllama.model_default_params(**mparams)
@@ -79,6 +86,7 @@ class LibLlamaModel:
raise RuntimeError("error: failed to create context for model '%s'" % path_model)
n_tokens_max = self.lib.llama_n_ctx(self.ctx)
self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
+ self.text_buff = self.ffi.new("uint8_t[]", 1024)
def free(self):
if self.ctx:
@@ -89,14 +97,78 @@ class LibLlamaModel:
self.model = None
self.lib = None
- def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
- n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
- text = text.encode("utf-8")
- num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
- if num < 0:
- return []
+ def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
+ encoded_text: bytes = text.encode("utf-8")
+ num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
+ while num < 0 and len(self.token_ids) < (16 << 20):
+ self.token_ids = self.ffi.new("llama_token[]", -2 * num)
+ num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
return list(self.token_ids[0:num])
+ def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
+ if len(self.token_ids) < len(ids):
+ self.token_ids = self.ffi.new("llama_token[]", 2 * len(ids))
+ for i, id in enumerate(ids):
+ self.token_ids[i] = id
+ num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+ while num < 0 and len(self.text_buff) < (16 << 20):
+ self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
+ num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+ return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
+
+
+class Tokenizer:
+
+ def encode(self, text: str) -> list[int]:
+ raise NotImplementedError
+
+ def decode(self, ids: list[int]) -> str:
+ raise NotImplementedError
+
+
+class TokenizerGroundtruth (Tokenizer):
+
+ def __init__(self, dir_tokenizer: str):
+ self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+ # guess BOS and EOS
+ ids = self.encode("a")
+ assert 1 <= len(ids) <= 3
+ add_bos_token = len(ids) > 1 and self.model.bos_token_id == ids[0]
+ add_eos_token = len(ids) > 1 and self.model.eos_token_id == ids[-1]
+ self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
+ self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
+ # build vocab
+ tokens = list(self.model.get_vocab().values())
+ self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
+ self.vocab = list(sorted(self.vocab))
+ # tokens and lists
+ self.special_tokens = list(self.model.all_special_tokens)
+ self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
+ self.bos_token = self.model.bos_token
+ self.eos_token = self.model.eos_token
+
+ def encode(self, text: str) -> list[int]:
+ return self.model.encode(text, add_special_tokens=True)
+
+ def decode(self, ids: list[int]) -> str:
+ return self.model.decode(ids, skip_special_tokens=False)
+
+
+class TokenizerLlamaCpp (Tokenizer):
+
+ libllama: LibLlama | None = None
+
+ def __init__(self, vocab_file: str):
+ if not self.libllama:
+ self.libllama = LibLlama()
+ self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+
+ def encode(self, text: str) -> list[int]:
+ return self.model.tokenize(text, add_special=True, parse_special=True)
+
+ def decode(self, ids: list[int]) -> str:
+ return self.model.detokenize(ids, remove_special=False, unparse_special=True)
+
def generator_custom_text() -> Iterator[str]:
"""General tests"""
@@ -160,24 +232,54 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'a\na', # bert fail
'"`', # falcon
' \u2e4e', # falcon
+ '\n\x0b ', # falcon
'a\xa0\xa0\x00b', # jina-v2-es
'one <mask>', # jina-v2-es <mask> lstrip=true
'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2
'\xa0aC', # deepseek
+ '\u2029 \uA3E4', # deepseek-llm
+ "a ?",
+ 'å', # mpt
+ '\U000ac517', # utf-8 encode error, falcon
+ '\U000522f4', # utf-8 encode error, starcoder
+ "<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>",
+ "<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>",
]
-def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
"""Brute force check all vocab words"""
- yield from vocab
-
-
-def generator_added_lr_strip(tokenizer) -> Iterator[str]:
- WHITESPACES = ["", " ", " ", " "]
- special_tokens = list(tokenizer.all_special_tokens)
- added_tokens = list(tokenizer.added_tokens_encoder)
- all_tokens = list(sorted(set(special_tokens + added_tokens)))
+ yield from tokenizer.vocab
+
+
+def generator_ascii_lr_strip() -> Iterator[str]:
+ WHITESPACES = ["", " ", " "]
+ CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+ for char1 in CHARACTERS:
+ for char2 in CHARACTERS:
+ for lstrip in WHITESPACES:
+ for rstrip in WHITESPACES:
+ yield lstrip + char1 + char2 + rstrip
+ yield lstrip + char1 + rstrip + char2
+ yield char1 + lstrip + char2 + rstrip
+
+
+def generator_apostrophe() -> Iterator[str]:
+ WHITESPACES = ["", " ", " "]
+ CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+ for char1 in CHARACTERS:
+ for char2 in CHARACTERS:
+ for lstrip in WHITESPACES:
+ for rstrip in WHITESPACES:
+ yield char1 + lstrip + "'" + rstrip + char2
+ yield char1 + char2 + lstrip + "'" + rstrip + "z"
+ yield "a" + lstrip + "'" + rstrip + char1 + char2
+
+
+def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
+ WHITESPACES = ["", " ", " ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
+ all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
for token in all_tokens:
for lstrip in WHITESPACES:
for rstrip in WHITESPACES:
@@ -187,11 +289,9 @@ def generator_added_lr_strip(tokenizer) -> Iterator[str]:
yield "a" + lstrip + token + rstrip + "z"
-def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
- special_tokens = list(tokenizer.all_special_tokens)
- added_tokens = list(tokenizer.added_tokens_encoder)
- separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
- all_tokens = list(sorted(set(special_tokens + added_tokens + separations)))
+def generator_random_added_tokens(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
+ separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
+ all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens + separations)))
rand = random.Random()
for m in range(iterations):
rand.seed(m)
@@ -242,13 +342,13 @@ def generator_unicodes() -> Iterator[str]:
def _valid(cpt):
if cpt >= 0x30000: # unassigned and supplementary
return False
- if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
- return False
- if unicodedata.category(chr(cpt)) == "Cn":
+ # if cpt == 0x2029: # deepseek-llm
+ # return False
+ if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"): # undefined, surrogates, private
return False
return True
- characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
+ characters = [chr(cpt) for cpt in range(0, MAX_CODEPOINTS) if _valid(cpt)]
yield from characters
@@ -273,11 +373,11 @@ def generator_random_unicodes(iterations=100) -> Iterator[str]:
yield "".join(text)
-def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
+def generator_random_vocab_chars(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
"""Brute force random text with vocab characters"""
vocab_chars = set()
- for word in vocab:
+ for word in tokenizer.vocab:
vocab_chars.update(word)
vocab_chars = list(sorted(vocab_chars))
@@ -288,10 +388,10 @@ def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[s
yield "".join(text)
-def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
+def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
"""Brute force random text from vocab words"""
- vocab = [w.strip() for w in vocab]
+ vocab = [w.strip() for w in tokenizer.vocab]
yield from vocab
rand = random.Random()
@@ -307,9 +407,9 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
yield "".join(text)
-def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
- def find_first_mismatch(ids1: list[int], ids2: list[int]):
+ def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
for i, (a, b) in enumerate(zip(ids1, ids2)):
if a != b:
return i
@@ -317,115 +417,150 @@ def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, gener
return -1
return min(len(ids1), len(ids2))
- t_tokenizer1 = 0
- t_tokenizer2 = 0
+ def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+ if text1 == text2: # equal to TokenizerGroundtruth?
+ return True
+ # equal to source text?
+ if tokenizer1.add_bos_token: # remove BOS
+ if text2.startswith(tokenizer1.bos_token):
+ text2 = text2[len(tokenizer1.bos_token):]
+ if tokenizer1.add_eos_token: # remove EOS
+ if text2.endswith(tokenizer1.eos_token):
+ text2 = text2[:-len(tokenizer1.eos_token)]
+ return text == text2
+
+ t_encode1 = 0
+ t_encode2 = 0
+ t_decode1 = 0
+ t_decode2 = 0
t_start = time.perf_counter()
- num_errors = 10
+ encode_errors = 0
+ decode_errors = 0
+ MAX_ERRORS = 10
- logger.info("%s: %s" % (generator.__name__, "ini"))
+ logger.info("%s: %s" % (generator.__qualname__, "ini"))
for text in generator:
+ # print(repr(text), text.encode())
# print(repr(text), hex(ord(text[0])), text.encode())
t0 = time.perf_counter()
- ids1 = func_tokenize1(text)
+ ids1 = tokenizer1.encode(text)
t1 = time.perf_counter()
- ids2 = func_tokenize2(text)
+ ids2 = tokenizer2.encode(text)
t2 = time.perf_counter()
- t_tokenizer1 += t1 - t0
- t_tokenizer2 += t2 - t1
- if ids1 != ids2:
+ text1 = tokenizer1.decode(ids1)
+ t3 = time.perf_counter()
+ text2 = tokenizer2.decode(ids1)
+ t4 = time.perf_counter()
+ t_encode1 += t1 - t0
+ t_encode2 += t2 - t1
+ t_decode1 += t3 - t2
+ t_decode2 += t4 - t3
+ if encode_errors < MAX_ERRORS and ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
- logger.error(" TokenIDs: " + str(ids1))
- logger.error(" Expected: " + str(ids2))
+ logger.error(" Expected: " + str(ids1))
+ logger.error(" Result: " + str(ids2))
+ encode_errors += 1
+ logger.error(f" {encode_errors=}")
+ if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+ i = find_first_mismatch(text1, text2)
+ text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+ text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+ logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+ logger.error(" Result: " + " ".join(hex(ord(x)) for x in text2))
+ decode_errors += 1
+ logger.error(f" {decode_errors=}")
+ if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+ logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
# raise Exception()
- num_errors += 1
- if num_errors > 10:
- break
+ break
t_total = time.perf_counter() - t_start
- logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
+ logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
-def main(argv: list[str] = None):
+def main(argv: list[str] | None = None):
parser = argparse.ArgumentParser()
- parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
- parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+ parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
+ parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args(argv)
logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
logger.info(f"VOCABFILE: '{args.vocab_file}'")
- model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
- tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
-
- def func_tokenize1(text: str):
- return model.tokenize(text, add_special=True, parse_special=True)
-
- def func_tokenize2(text: str):
- return tokenizer.encode(text, add_special_tokens=True)
+ tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
+ tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
- ids = func_tokenize2("a")
- assert 1 <= len(ids) <= 3
- add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
- add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
- tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
- tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+ compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
+ compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
+ compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
+ compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
+ compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
+ # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
- vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
-
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
- compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
-
- model.free()
+ tokenizer2.model.free()
if __name__ == "__main__":
# main()
+ if True:
+ logging.basicConfig(
+ level = logging.DEBUG,
+ format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
+ datefmt = "%Y-%m-%d %H:%M:%S",
+ filename = logger.name + ".log",
+ filemode = "a"
+ )
logging.basicConfig(
level = logging.DEBUG,
- format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
- datefmt = "%Y-%m-%d %H:%M:%S",
- filename = logger.name + ".log",
- filemode = "a"
+ format = "%(levelname)s %(message)s",
)
- path_tokenizers = "./models/tokenizers/"
+ path_tokenizers = Path("./models/tokenizers/")
path_vocab_format = "./models/ggml-vocab-%s.gguf"
- # import os
- # tokenizers = os.listdir(path_tokenizers)
tokenizers = [
- # "llama-spm", # SPM
- # "phi-3", # SPM
- # "bert-bge", # WPM
- # "jina-v2-en", # WPM
- "gpt-2", # BPE
+ "llama-spm", # SPM
+ "phi-3", # SPM
+ "gemma", # SPM
+ "gemma-2", # SPM
+ "baichuan", # SPM
+ "bert-bge", # WPM
+ "jina-v2-en", # WPM
"llama-bpe", # BPE
+ "phi-2", # BPE
+ "deepseek-llm", # BPE
+ "deepseek-coder", # BPE
"falcon", # BPE
+ "mpt", # BPE
"starcoder", # BPE
+ "gpt-2", # BPE
+ "stablelm2", # BPE
+ "refact", # BPE
+ "qwen2", # BPE
+ "olmo", # BPE
"jina-v2-es", # BPE
"jina-v2-de", # BPE
- "jina-v2-code", # BPE
"smaug-bpe", # BPE
- "phi-2", # BPE
- "deepseek-coder", # BPE
- "deepseek-llm", # BPE
+ "poro-chat", # BPE
+ "jina-v2-code", # BPE
+ "viking", # BPE
+ "jais", # BPE
]
+ logger.info("=" * 50)
for tokenizer in tokenizers:
- logger.info("=" * 50)
+ logger.info("-" * 50)
logger.info(f"TOKENIZER: '{tokenizer}'")
- vocab_file = path_vocab_format % tokenizer
- dir_tokenizer = path_tokenizers + "/" + tokenizer
- main([vocab_file, dir_tokenizer, "--verbose"])
+ vocab_file = Path(path_vocab_format % tokenizer)
+ dir_tokenizer = path_tokenizers / tokenizer
+ main([str(vocab_file), str(dir_tokenizer), "--verbose"])
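One idiom from the reworked Python harness deserves a note: LibLlamaModel.tokenize and detokenize no longer fail on undersized buffers but grow them and retry, exploiting that the C API returns the negated required size, with a 16M-entry ceiling as a safety stop. The same pattern expressed directly against llama_tokenize (tokenize_growing is an illustrative wrapper):

    #include "llama.h"
    #include <string>
    #include <vector>

    static std::vector<llama_token> tokenize_growing(
            const llama_model * model, const std::string & text,
            bool add_special, bool parse_special) {
        std::vector<llama_token> toks(4096);
        int32_t num = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                                     toks.data(), (int32_t) toks.size(),
                                     add_special, parse_special);
        // A negative result is -(required token count): double it and retry,
        // but never grow past the 16M-entry safety cap.
        while (num < 0 && toks.size() < (16u << 20)) {
            toks.resize((size_t) (-2 * num));
            num = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                                 toks.data(), (int32_t) toks.size(),
                                 add_special, parse_special);
        }
        toks.resize(num > 0 ? (size_t) num : 0);
        return toks;
    }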