| 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
 | #define LLAMA_API_INTERNAL
#include "grammar-parser.h"
#include "ggml.h"
#include "llama.h"
#include "unicode.h"
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
// Runs `input_str` through `grammar` one Unicode code point at a time.
//
// Returns true when every code point is accepted and, afterwards, at least one
// grammar stack is empty (the grammar may legally end here). Otherwise returns
// false and sets:
//   error_pos - 0-based index where matching failed, counted in code points
//               (not bytes)
//   error_msg - human-readable reason
// grammar->stacks is advanced in place. When a code point is rejected, the stacks
// are restored to their state just before that code point.
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
    auto decoded = decode_utf8(input_str, {});
    const auto & code_points = decoded.first;

    // NOTE(review): decode_utf8 does not fail loudly. A bad lead byte is reported
    // via n_remain < 0 (and the code points decoded so far are discarded). A
    // truncated final sequence is reported via n_remain > 0. Without this check,
    // such input would be validated as if it were empty or cut short.
    if (decoded.second.n_remain != 0) {
        // Best-effort position: the number of complete code points that survived decoding.
        error_pos = code_points.empty() ? 0 : code_points.size() - 1;
        error_msg = "Invalid UTF-8 in input";
        return false;
    }

    // The last decoded entry is a terminator appended by decode_utf8, not input, so
    // only the first size() - 1 entries are fed to the grammar. An index loop stays
    // well-defined even if the decoded sequence were empty (unlike end() - 1).
    size_t pos = 0;
    for (; pos + 1 < code_points.size(); ++pos) {
        const auto cpt = code_points[pos];
        auto prev_stacks = grammar->stacks;
        llama_grammar_accept(grammar->rules, prev_stacks, cpt, grammar->stacks);
        if (grammar->stacks.empty()) {
            error_pos = pos;
            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
            // Roll back so the grammar reflects the last accepted position.
            grammar->stacks = prev_stacks;
            return false;
        }
    }

    // The input is a complete match only if some stack has been fully consumed.
    for (const auto & stack : grammar->stacks) {
        if (stack.empty()) {
            return true;
        }
    }

    error_pos = pos;
    error_msg = "Unexpected end of input";
    return false;
}
// Prints a validation failure report to stdout and echoes `input_str`, highlighting
// the offending character in bold red and the remainder of the input in red
// (ANSI escape codes).
//
// `error_pos` is a 0-based index counted in Unicode code points, as produced by
// llama_sample_grammar_string. It is translated to a byte offset here, so that
// multi-byte UTF-8 characters are neither mislocated nor split. If `error_pos` is
// at or past the last code point (e.g. "Unexpected end of input"), the input is
// echoed without highlighting.
static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) {
    fprintf(stdout, "Input string is invalid according to the grammar.\n");
    fprintf(stdout, "Error: %s at position %zu\n", error_msg.c_str(), error_pos);
    fprintf(stdout, "\n");
    fprintf(stdout, "Input string:\n");

    // A byte starts a new code point unless it is a UTF-8 continuation byte (10xxxxxx).
    const auto is_cpt_start = [](char c) {
        return (static_cast<unsigned char>(c) & 0xC0) != 0x80;
    };

    // Translate the code-point index into the byte offset where that code point begins.
    const size_t n_bytes = input_str.size();
    size_t err_begin = n_bytes;
    for (size_t i = 0, cpt = 0; i < n_bytes; ++i) {
        if (is_cpt_start(input_str[i])) {
            if (cpt == error_pos) {
                err_begin = i;
                break;
            }
            ++cpt;
        }
    }

    // The offending character spans its lead byte plus any following continuation bytes.
    size_t err_end = err_begin;
    if (err_end < n_bytes) {
        ++err_end;
        while (err_end < n_bytes && !is_cpt_start(input_str[err_end])) {
            ++err_end;
        }
    }

    // fwrite emits the exact bytes, including any embedded NULs that %s would stop at.
    fwrite(input_str.data(), 1, err_begin, stdout);
    if (err_begin < n_bytes) {
        fprintf(stdout, "\033[1;31m");
        fwrite(input_str.data() + err_begin, 1, err_end - err_begin, stdout);
        if (err_end < n_bytes) {
            fprintf(stdout, "\033[0;31m");
            fwrite(input_str.data() + err_end, 1, n_bytes - err_end, stdout);
        }
        fprintf(stdout, "\033[0m\n");
    }
}
// Reads the entire file at `path` into `out`. `what` names the file in error
// messages (e.g. "grammar" -> "Failed to open grammar file: ...").
// The file is read in chunks instead of pre-sizing a buffer with fseek/ftell, so
// `out` holds exactly the bytes fread returned. This means no padding when
// text-mode newline translation yields fewer bytes than ftell reported, and
// non-seekable inputs such as pipes also work.
// Returns false, after printing a message to stdout, if the file cannot be opened
// or a read error occurs.
static bool read_file_contents(const std::string & path, const char * what, std::string & out) {
    FILE * file = fopen(path.c_str(), "r");
    if (!file) {
        fprintf(stdout, "Failed to open %s file: %s\n", what, path.c_str());
        return false;
    }

    out.clear();
    char buf[4096];
    size_t n_read = 0;
    while ((n_read = fread(buf, 1, sizeof(buf), file)) > 0) {
        out.append(buf, n_read);
    }

    const bool read_failed = ferror(file) != 0;
    fclose(file);
    if (read_failed) {
        fprintf(stdout, "Failed to read %s file: %s\n", what, path.c_str());
        return false;
    }
    return true;
}

// Usage: <program> <grammar_filename> <input_filename>
// Parses a GBNF grammar and reports whether the input file's contents match its
// "root" rule. Returns 1 on usage, I/O, or grammar setup errors. Returns 0 once
// validation has run, whether or not the input matched.
int main(int argc, char** argv) {
    if (argc != 3) {
        fprintf(stdout, "Usage: %s <grammar_filename> <input_filename>\n", argv[0]);
        return 1;
    }

    const std::string grammar_filename = argv[1];
    const std::string input_filename = argv[2];

    // Read the GBNF grammar file
    std::string grammar_str;
    if (!read_file_contents(grammar_filename, "grammar", grammar_str)) {
        return 1;
    }

    // Parse the GBNF grammar
    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());

    // will be empty (default) if there are parse errors
    if (parsed_grammar.rules.empty()) {
        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
        return 1;
    }

    // Ensure that there is a "root" node.
    const auto root_it = parsed_grammar.symbol_ids.find("root");
    if (root_it == parsed_grammar.symbol_ids.end()) {
        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
        return 1;
    }

    // Read the input file before creating the grammar, so no early return has to
    // release the grammar (the original leaked it when the input file was missing).
    std::string input_str;
    if (!read_file_contents(input_filename, "input", input_str)) {
        return 1;
    }

    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());

    // Create the LLAMA grammar
    auto * grammar = llama_grammar_init(
            grammar_rules.data(),
            grammar_rules.size(), root_it->second);
    // NOTE(review): guard against a null result in case llama_grammar_init can
    // reject a grammar in the linked llama version — confirm against llama.h.
    if (grammar == nullptr) {
        fprintf(stdout, "%s: failed to initialize grammar\n", __func__);
        return 1;
    }

    // Validate the input string against the grammar
    size_t error_pos = 0;
    std::string error_msg;
    const bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
    if (is_valid) {
        fprintf(stdout, "Input string is valid according to the grammar.\n");
    } else {
        print_error_message(input_str, error_pos, error_msg);
    }

    // Clean up
    llama_grammar_free(grammar);

    return 0;
}
 |