diff options
Diffstat (limited to 'examples/baby-llama/baby-llama.cpp')
-rw-r--r-- | examples/baby-llama/baby-llama.cpp | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 8155101d..2dc2988d 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1258,9 +1258,9 @@ static struct ggml_tensor * forward_lora( } static void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) { - assert(logits->n_dims == 2); - assert(probs->n_dims == 2); - assert(best_samples->n_dims == 1); + assert(ggml_is_matrix(logits)); + assert(ggml_is_matrix(probs)); + assert(ggml_is_vector(best_samples)); assert(logits->ne[1] == best_samples->ne[0]); assert(logits->ne[0] == probs->ne[0]); assert(logits->ne[1] == probs->ne[1]); @@ -1292,9 +1292,9 @@ static void sample_softmax_batch( struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples ) { - GGML_ASSERT(best_samples->n_dims == 2); - GGML_ASSERT(logits->n_dims == 3); - GGML_ASSERT(probs->n_dims == 3); + GGML_ASSERT(ggml_is_matrix(best_samples)); + GGML_ASSERT(ggml_is_3d(logits)); + GGML_ASSERT(ggml_is_3d(probs)); int n_tokens = best_samples->ne[0]; int n_batch = best_samples->ne[1]; int n_vocab = logits->ne[0]; @@ -1334,7 +1334,7 @@ static void print_row(struct ggml_tensor * probs, int i) { } static void print_matrix(struct ggml_tensor * probs) { - assert(probs->n_dims == 2); + assert(ggml_is_matrix(probs)); for (int i = 0; i < probs->ne[1]; ++i) { for (int k = 0; k < probs->ne[0]; ++k) { float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); @@ -1386,8 +1386,8 @@ static void get_example_targets(int example_id, struct ggml_tensor * tokens_inpu static void get_example_targets_batch( struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets ) { - GGML_ASSERT(tokens_input->n_dims == 2); - GGML_ASSERT( targets->n_dims == 3); + GGML_ASSERT(ggml_is_matrix(tokens_input)); + GGML_ASSERT(ggml_is_3d(targets)); int n_tokens = tokens_input->ne[0]; int n_batch = tokens_input->ne[1]; GGML_ASSERT(n_tokens == targets->ne[1]); |