summaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c30
1 files changed, 26 insertions, 4 deletions
diff --git a/ggml.c b/ggml.c
index f5caeba0..6dbd7626 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2354,6 +2354,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
}
void ggml_free(struct ggml_context * ctx) {
+ if (ctx == NULL) {
+ return;
+ }
+
// make this function thread safe
ggml_critical_section_start();
@@ -4362,6 +4366,23 @@ struct ggml_tensor * ggml_cpy(
return ggml_cpy_impl(ctx, a, b);
}
+struct ggml_tensor * ggml_cast(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ enum ggml_type type) {
+ bool is_node = false;
+
+ struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
+ ggml_format_name(result, "%s (copy)", a->name);
+
+ result->op = GGML_OP_CPY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = result;
+
+ return result;
+}
+
// ggml_cont
static struct ggml_tensor * ggml_cont_impl(
@@ -14871,7 +14892,7 @@ size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tenso
return i;
}
-static struct ggml_hash_set ggml_hash_set_new(size_t size) {
+struct ggml_hash_set ggml_hash_set_new(size_t size) {
size = ggml_hash_size(size);
struct ggml_hash_set result;
result.size = size;
@@ -16620,7 +16641,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
return GGML_EXIT_SUCCESS;
}
-struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
@@ -16682,14 +16703,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_MUL_MAT_ID:
{
+ cur = 0;
const struct ggml_tensor * src0 = node->src[2];
const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
- cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
+ cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = ggml_get_op_params_i32(node, 1);
- cur = GGML_PAD(cur, sizeof(int64_t)); // align
+ cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
} break;