From 424b6381c4daeed62e6600e0402e72f39845b58d Mon Sep 17 00:00:00 2001 From: slaren Date: Fri, 13 Oct 2023 12:23:10 +0200 Subject: ggml : add context enumeration functions (#3605) finetune : fix assert failure in ggml-alloc --- ggml.c | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) (limited to 'ggml.c') diff --git a/ggml.c b/ggml.c index c00ab00d..630deb49 100644 --- a/ggml.c +++ b/ggml.c @@ -5494,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor( return result; } +struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_object * obj = ctx->objects_begin; @@ -8647,6 +8680,7 @@ void ggml_set_param( GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); + ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // ggml_compute_forward_dup -- cgit v1.2.3