summaryrefslogtreecommitdiff
path: root/examples/baby-llama/baby-llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/baby-llama/baby-llama.cpp')
-rw-r--r--examples/baby-llama/baby-llama.cpp18
1 files changed, 9 insertions, 9 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 8155101d..2dc2988d 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -1258,9 +1258,9 @@ static struct ggml_tensor * forward_lora(
}
static void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
- assert(logits->n_dims == 2);
- assert(probs->n_dims == 2);
- assert(best_samples->n_dims == 1);
+ assert(ggml_is_matrix(logits));
+ assert(ggml_is_matrix(probs));
+ assert(ggml_is_vector(best_samples));
assert(logits->ne[1] == best_samples->ne[0]);
assert(logits->ne[0] == probs->ne[0]);
assert(logits->ne[1] == probs->ne[1]);
@@ -1292,9 +1292,9 @@ static void sample_softmax_batch(
struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs,
struct ggml_tensor * best_samples
) {
- GGML_ASSERT(best_samples->n_dims == 2);
- GGML_ASSERT(logits->n_dims == 3);
- GGML_ASSERT(probs->n_dims == 3);
+ GGML_ASSERT(ggml_is_matrix(best_samples));
+ GGML_ASSERT(ggml_is_3d(logits));
+ GGML_ASSERT(ggml_is_3d(probs));
int n_tokens = best_samples->ne[0];
int n_batch = best_samples->ne[1];
int n_vocab = logits->ne[0];
@@ -1334,7 +1334,7 @@ static void print_row(struct ggml_tensor * probs, int i) {
}
static void print_matrix(struct ggml_tensor * probs) {
- assert(probs->n_dims == 2);
+ assert(ggml_is_matrix(probs));
for (int i = 0; i < probs->ne[1]; ++i) {
for (int k = 0; k < probs->ne[0]; ++k) {
float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
@@ -1386,8 +1386,8 @@ static void get_example_targets(int example_id, struct ggml_tensor * tokens_inpu
static void get_example_targets_batch(
struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets
) {
- GGML_ASSERT(tokens_input->n_dims == 2);
- GGML_ASSERT( targets->n_dims == 3);
+ GGML_ASSERT(ggml_is_matrix(tokens_input));
+ GGML_ASSERT(ggml_is_3d(targets));
int n_tokens = tokens_input->ne[0];
int n_batch = tokens_input->ne[1];
GGML_ASSERT(n_tokens == targets->ne[1]);