From e8a7fd4fb06d82f663850c21fcf86c0fb98ad9b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 May 2024 19:09:30 +0300 Subject: metal : support FA without mask + add asserts (#7278) * ggml : fa without mask + add asserts ggml-ci * metal : support non-contiguous KV ggml-ci --- ggml.c | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'ggml.c') diff --git a/ggml.c b/ggml.c index d443a9b4..03b609dd 100644 --- a/ggml.c +++ b/ggml.c @@ -2824,6 +2824,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor (t0->ne[3] == t1->ne[3] ); } +bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->nb[0] == t1->nb[0] ) && + (t0->nb[1] == t1->nb[1] ) && + (t0->nb[2] == t1->nb[2] ) && + (t0->nb[3] == t1->nb[3] ); +} + // check if t1 can be represented as a repeatition of t0 static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); -- cgit v1.2.3