// Source: tests/test-tokenizer-0.cpp (blob 7e9ac9188d5c5b15f2e90c1e321ab51c1ac0e328)
#include "llama.h"
#include "common.h"

#include <cstdio>
#include <string>
#include <map>
#include <vector>

// Detokenizes a token sequence.
//
// ctx    - context handed to llama_token_to_str() for each lookup
// tokens - token ids to convert, in order
//
// Returns the concatenation of the llama_token_to_str() result of every token,
// in input order; an empty sequence yields an empty string.
static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
    std::string text;
    for (const auto & token : tokens) {
        text += llama_token_to_str(ctx, token);
    }
    return text;
}

static const std::map<std::string, std::vector<llama_token>> & k_tests() {
    static std::map<std::string, std::vector<llama_token>> _k_tests = {
        { " ",                      {1,    259, }, },
        { "  ",                     { 1,    1678, }, },
        { "   ",                    { 1,     268, }, },
        { "\t",                     { 1,    29871,   12, }, },
        { "\n",                     { 1,    29871,   13, }, },
        { "\t\n",                   { 1,    29871,   12,     13, }, },
        { "Hello world",            { 1,  15043,   3186, }, },
        { " Hello world",           { 1,  29871,  15043,   3186, }, },
        { "Hello World",            { 1,  15043,   2787, }, },
        { " Hello World",           { 1,  29871,  15043,   2787, }, },
        { " Hello World!",          { 1,  29871,  15043,   2787,  29991, }, },
        { " this is πŸ¦™.cpp",        { 1,  29871,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, },
        { "w048 7tuijk dsdfhu",     { 1,    281,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, },
        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ",      { 1,   1538,   4851,    665,   1386,  29713,   1305, }, },
        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",   { 1,  29871,  31849,  31324,  31934,    228,    162,    142,    228,    161,
                                     146,    228,    162,    133,    228,    161,    153,    228,    161,    186,
                                     31708,    228,    162,    132,  31708,    228,    161,    165,  31324,    228,
                                     161,    136,    228,    161,    132,    228,    161,    158,    228,    161,
                                     136,    228,    162,    132,    228,    161,    140, }, },
        { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
            { 1,  29871,    243,    162,    157,    131,    313,   8945,  29897,  29871,
                243,    162,    155,    185,  30722,    243,    162,    143,    174,  30598,
                313,  20787,    953,   3848,    275,  16125,    630,  29897,  29871,  31681,
                313,   6194,    953,  29877,   2397,    393,    756,    967,   1914,   5993,  29897, }, },
        { "Hello",                  { 1,    15043 }, },
        { " Hello",                 { 1,    29871,  15043 }, },
        { "  Hello",                { 1,    259,    15043 }, },
        { "   Hello",               { 1,    1678,   15043 }, },
        { "    Hello",              { 1,    268,    15043 }, },
        { "    Hello\n    Hello",   { 1,    268,    15043,  13,     1678,   15043 }, },
    };

    return _k_tests;
}

// Tokenizer regression test driver.
//
// Usage: <program> <vocab-file>
//
// Loads only the vocabulary from the given file, tokenizes every input of
// k_tests() (with a single space prepended) and compares the result with the
// expected token ids.
//
// Exit codes:
//   0 - every test case matched
//   1 - missing argument, or the model / context could not be created
//   2 - the vocabulary does not contain exactly 32000 tokens
//   3 - at least one test case produced unexpected tokens
int main(int argc, char **argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    llama_model * model = nullptr;
    llama_context * ctx = nullptr;

    llama_backend_init(false);

    // load the vocab
    {
        auto lparams = llama_context_default_params();

        lparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), lparams);

        if (model == nullptr) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            // the backend was already initialized above, so pair it with a free
            llama_backend_free();
            return 1;
        }

        ctx = llama_new_context_with_model(model, lparams);

        if (ctx == nullptr) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            llama_backend_free();
            return 1;
        }
    }

    const int n_vocab = llama_n_vocab(ctx);

    if (n_vocab != 32000) {
        fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
        // release the context before the model it was created from
        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 2;
    }

    bool success = true;

    for (const auto & test_kv : k_tests()) {
        const std::string              & text     = test_kv.first;
        const std::vector<llama_token> & expected = test_kv.second;

        // Add a space in front of the first character to match OG llama tokenizer behavior
        const std::vector<llama_token> res = llama_tokenize(ctx, " " + text, true);

        // detokenize once; reused by the failure report below
        const std::string detokenized = unescape_whitespace(ctx, res);

        fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
            __func__, text.c_str(), detokenized.c_str());

        // vector equality checks both the length and every element
        if (res == expected) {
            continue;
        }

        fprintf(stderr, "%s : failed test:    '%s'\n", __func__, text.c_str());
        fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
            detokenized.c_str(), unescape_whitespace(ctx, expected).c_str());
        fprintf(stderr, "%s : expected tokens: ", __func__);
        for (const auto & t : expected) {
            fprintf(stderr, "%6d, ", t);
        }
        fprintf(stderr, "\n");
        fprintf(stderr, "%s : got tokens:      ", __func__);
        for (const auto & t : res) {
            fprintf(stderr, "%6d, ", t);
        }
        fprintf(stderr, "\n");

        // keep going so that every failing case is reported in a single run
        success = false;
    }

    // release the context before the model it was created from
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return success ? 0 : 3;
}