summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-14 19:09:30 +0300
committerGitHub <noreply@github.com>2024-05-14 19:09:30 +0300
commite8a7fd4fb06d82f663850c21fcf86c0fb98ad9b4 (patch)
treecf3e07d88d47f14717ae7ae6923b653950d87911 /ggml.c
parenta5e3fde8578d54b98d941344a4da150669af200d (diff)
metal : support FA without mask + add asserts (#7278)
* ggml : fa without mask + add asserts ggml-ci * metal : support non-contiguous KV ggml-ci
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c10
1 files changed, 10 insertions, 0 deletions
diff --git a/ggml.c b/ggml.c
index d443a9b4..03b609dd 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2824,6 +2824,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
(t0->ne[3] == t1->ne[3] );
}
+bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->nb[0] == t1->nb[0] ) &&
+ (t0->nb[1] == t1->nb[1] ) &&
+ (t0->nb[2] == t1->nb[2] ) &&
+ (t0->nb[3] == t1->nb[3] );
+}
+
// check if t1 can be represented as a repeatition of t0
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");