summary | refs | log | tree | commit | diff
path: root/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'main.cpp')
-rw-r--r-- main.cpp 8
1 files changed, 5 insertions, 3 deletions
diff --git a/main.cpp b/main.cpp
index 982adf16..eca71408 100644
--- a/main.cpp
+++ b/main.cpp
@@ -400,7 +400,7 @@ bool llama_eval(
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
- const int n_rot = hparams.n_rot;
+ const int n_rot = hparams.n_embd/hparams.n_head;
const int d_key = n_embd/n_head;
@@ -628,6 +628,9 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}
+// params.prompt = R"(// this function checks if the number n is prime
+//bool is_prime(int n) {)";
+
int64_t t_load_us = 0;
gpt_vocab vocab;
@@ -691,7 +694,6 @@ int main(int argc, char ** argv) {
if (i >= embd_inp.size()) {
// sample next token
- const int top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
@@ -702,7 +704,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();
- id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+ id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng);
t_sample_us += ggml_time_us() - t_start_sample_us;
}